Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,065 Bytes
7bb080d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
# --- START OF FILE app.py ---
import sys
import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
from dotenv import load_dotenv
# --- FIX: Add project root to Python's path ---
# Put the directory containing this file first on sys.path so sibling modules
# import regardless of the working directory the app is launched from.
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, project_root)

# --- Updated Spaces import for Zero-GPU compatibility ---
try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")

    class DummySpaces:
        """Local stand-in for the `spaces` package whose `GPU` decorator is a no-op.

        Accepts both decorator forms:
        bare ``@spaces.GPU`` and called ``@spaces.GPU(...)``.
        """

        def GPU(self, *args, **kwargs):
            """Return the decorated function unchanged (bare form) or a no-op decorator."""
            def decorator(func):
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func

            # Bare form: `@spaces.GPU` passes the decorated function itself as the
            # only argument. Without this branch the function would be replaced
            # by `decorator`, and every later call to it would fail.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return decorator(args[0])
            # Called form: `@spaces.GPU(...)` -> return the actual decorator.
            return decorator

    spaces = DummySpaces()
# --- Step 1: Hugging Face Authentication ---
# Load variables from a local .env file (if present) into the environment,
# insist that HF_TOKEN is set to a non-empty value, then log in to the Hub.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None or len(HF_TOKEN) == 0:
    raise ValueError("FATAL: Hugging Face token not found. Please set the HF_TOKEN environment variable.")

print("--- Logging in to Hugging Face Hub ---")
login(token=HF_TOKEN)
# --- Step 2: Initialize Model and Tokenizer (Load Once on Startup) ---
MODEL_NAME = "Gregniuki/ERNIE-4.5-0.3B-PT-Translator-EN-PL-EN"
print(f"--- Loading model from Hugging Face Hub: {MODEL_NAME} ---")

# --- Device Setup (Zero GPU Support) ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU detected. Using CUDA.")
else:
    device = torch.device("cpu")
    print("No GPU detected. Using CPU.")

# bfloat16 on GPU, float32 on CPU.
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
print(f"--- Using dtype: {dtype} ---")

# Marker placed between the context overlap and the new chunk so the overlap's
# translation can be cut away afterwards. Presumably a special token of this
# model's tokenizer -- NOTE(review): confirm it is in the vocabulary.
SPECIAL_MARKER = "<|LOC_0|>"

print(f"--- Loading tokenizer from Hub: {MODEL_NAME} ---")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )
except Exception as e:
    # Chain the cause so the original traceback is kept.
    raise RuntimeError(f"FATAL: Could not load tokenizer from the Hub. Error: {e}") from e
print(f"--- Using special marker for overlap: {SPECIAL_MARKER} ---")

print(f"--- Loading Model with PyTorch from Hub: {MODEL_NAME} ---")
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        trust_remote_code=True,
    ).to(device)
except Exception as e:
    raise RuntimeError(f"FATAL: Could not load model from the Hub. Error: {e}") from e
model.eval()  # inference only
print("--- Model Loaded Successfully ---")
# --- Helper function for chunking text ---
def chunk_text(text: str, max_size: int) -> list[str]:
    """Split *text* into chunks of at most *max_size* characters.

    Each window is cut just after the last '.' it contains, so sentences stay
    whole where possible. When the window has no '.', it is hard-cut at
    *max_size* characters. Every chunk is whitespace-stripped, and empty
    chunks are dropped.

    Args:
        text: The text to split.
        max_size: Maximum number of characters per chunk (measured before stripping).

    Returns:
        The non-empty, stripped chunks in order; ``[]`` for empty input.

    Raises:
        ValueError: If *max_size* is not positive (the loop could never advance).
    """
    if not text:
        return []
    if max_size <= 0:
        raise ValueError(f"max_size must be positive, got {max_size}")

    chunks: list[str] = []
    start = 0
    while start < len(text):
        end = start + max_size
        if end >= len(text):
            # The remainder fits in one chunk; strip it like every other chunk.
            chunks.append(text[start:].strip())
            break
        split_pos = text.rfind(".", start, end)
        # Cut right after the period when one exists, otherwise at the window edge.
        cut = split_pos + 1 if split_pos != -1 else end
        chunks.append(text[start:cut].strip())
        start = cut
    return [c for c in chunks if c]
# --- Simplified translation helper for internal use ---
def do_translation(text_to_translate: str) -> str:
    """Translate one piece of text with a single chat-formatted ``generate`` call.

    The text is sent as one user message through the tokenizer's chat template.
    Generation uses sampling (``do_sample=True``), so repeated calls may return
    different translations.

    Args:
        text_to_translate: Source text; may contain SPECIAL_MARKER.

    Returns:
        The newly generated text (prompt removed, trailing EOS/pad tokens
        removed), stripped of surrounding whitespace; ``""`` for blank input.
    """
    if not text_to_translate.strip():
        return ""
    messages = [{"role": "user", "content": text_to_translate}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Tokenize the templated prompt as-is; chat templates typically already
    # carry their own special tokens.
    model_inputs = tokenizer([prompt], add_special_tokens=False, return_tensors="pt").to(device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=2048,
        do_sample=True, temperature=0.7, top_p=0.95, top_k=50,
    )
    # generate() returns prompt + continuation; keep only the new tokens.
    input_token_len = model_inputs.input_ids.shape[1]
    output_ids = generated_ids[0][input_token_len:].tolist()

    # Special tokens are kept when decoding so the overlap marker survives. That
    # would also render the terminating EOS (or padding) as literal text in the
    # translation, so trim those trailing ids first. Markers inside the text are
    # untouched.
    stop_ids = {tid for tid in (tokenizer.eos_token_id, tokenizer.pad_token_id) if tid is not None}
    while output_ids and output_ids[-1] in stop_ids:
        output_ids.pop()

    # We use skip_special_tokens=False to ensure our marker isn't accidentally removed by the decoder
    return tokenizer.decode(output_ids, skip_special_tokens=False).strip()
# --- Step 3: Core Translation Function (USING SPECIAL TOKEN MARKER) ---
def _strip_overlap(full_translation: str) -> str:
    """Return the text after the first SPECIAL_MARKER in *full_translation*.

    If the model did not reproduce the marker, the whole translation (overlap
    included) is returned and a warning is printed.
    """
    marker_position = full_translation.find(SPECIAL_MARKER)
    if marker_position == -1:
        # The model likely translated or ignored the marker; the user will see the overlap.
        print(f"Warning: Marker '{SPECIAL_MARKER}' not found in translation. Overlap may remain.")
        return full_translation
    print("Special marker found in output. Removing overlap.")
    return full_translation[marker_position + len(SPECIAL_MARKER):].lstrip()


@spaces.GPU
@torch.no_grad()
def translate_with_chunks(input_text: str, chunk_size: int, context_words: int, progress=gr.Progress()) -> str:
    """Translate *input_text* chunk by chunk, carrying source-text context forward.

    The text is split with ``chunk_text``. From the second chunk on, the last
    *context_words* words of the previous source chunk are prepended, separated
    by SPECIAL_MARKER. The translated overlap before the marker is then cut away.

    Args:
        input_text: Text to translate.
        chunk_size: Maximum characters per chunk.
        context_words: Trailing words of the previous chunk to use as context; 0 disables it.
        progress: Gradio progress tracker (supplied by Gradio).

    Returns:
        The translated chunks joined with single spaces, with any leftover
        SPECIAL_MARKER removed. For blank input, a message string instead.
    """
    progress(0, desc="Starting...")
    print("--- Inference with special token context method started ---")
    if not input_text or not input_text.strip():
        return "Input text is empty. Please enter some text to translate."

    # Slider values can arrive as floats (e.g. from API clients); slicing needs ints.
    chunk_size = int(chunk_size)
    context_words = int(context_words)

    progress(0.1, desc="Chunking Text...")
    text_chunks = chunk_text(input_text, chunk_size) if len(input_text) > chunk_size else [input_text]
    num_chunks = len(text_chunks)
    print(f"Processing {num_chunks} chunk(s).")

    all_results = []
    # Tail of the previous *source* chunk (English or Polish, depending on direction).
    source_context = ""
    for i, chunk in enumerate(text_chunks):
        progress(0.2 + (i / num_chunks) * 0.7, desc=f"Translating chunk {i+1}/{num_chunks}")
        if source_context:
            full_translation = do_translation(f"{source_context} {SPECIAL_MARKER} {chunk}")
            all_results.append(_strip_overlap(full_translation))
        else:
            # First chunk (or context disabled): translate directly.
            all_results.append(do_translation(chunk))
        print(f"Chunk {i+1} processed successfully.")
        if context_words > 0:
            source_context = " ".join(chunk.split()[-context_words:])

    progress(0.95, desc="Reassembling Results...")
    # Remove any stray markers the model emitted anywhere in the output.
    full_output = " ".join(all_results).replace(SPECIAL_MARKER, "")
    progress(1.0, desc="Done!")
    return full_output
# --- Step 4: Create and Launch the Gradio App ---
print("\n--- Initializing Gradio Interface ---")
app = gr.Interface(
    fn=translate_with_chunks,
    inputs=[
        gr.Textbox(lines=15, label="Input Text", placeholder="Enter long text to process here..."),
        gr.Slider(minimum=128, maximum=1536, value=1024, step=64, label="Character Chunk Size"),
        gr.Slider(
            minimum=0,
            maximum=50,
            value=15,
            step=5,
            label="Context Overlap (Source Words)",
            # The model translates in both directions (EN<->PL), so the context
            # words come from the source text in either language.
            info="Number of source-text words from the end of the previous chunk to provide as context for the next one. Ensures consistency.",
        ),
    ],
    outputs=gr.Textbox(lines=15, label="Model Output", interactive=False),
    title="ERNIE 4.5 Context-Aware Translator",
    description="Processes long text using a special token marker to ensure high-quality, consistent translations without duplication.",
    allow_flagging="never",
)

if __name__ == "__main__":
    # Enable request queuing, then start the web server.
    app.queue().launch()