# --- START OF FILE app.py ---
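"""Gradio Space for the Gregniuki/ERNIE-4.5-0.3B-PT-Translator-EN-PL-EN model.

Loads the model and tokenizer once at startup, splits long input text into
sentence-aligned chunks, and removes the translated context overlap between
chunks via a marker-token search with a word-level diff fallback.
"""
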
import sys
import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
from dotenv import load_dotenv
# --- FIX: Add project root to Python's path ---
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, project_root)
# --- Updated Spaces import for Zero-GPU compatibility ---
try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")

    class DummySpaces:
        def GPU(self, *args, **kwargs):
            def decorator(func):
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator

    spaces = DummySpaces()

# --- Step 1: Hugging Face Authentication ---
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("FATAL: Hugging Face token not found. Please set the HF_TOKEN environment variable.")
print("--- Logging in to Hugging Face Hub ---")
login(token=HF_TOKEN)
# --- Step 2: Initialize Model and Tokenizer (Load Once on Startup) ---
MODEL_NAME = "Gregniuki/ERNIE-4.5-0.3B-PT-Translator-EN-PL-EN"
print(f"--- Loading model from Hugging Face Hub: {MODEL_NAME} ---")
# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
print(f"--- Using device: {device}, dtype: {dtype} ---")
# --- Load Tokenizer and Define Marker ---
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Use a dedicated separator token that exists in the model's vocabulary
    MARKER_STRING = "<|LOC_SEP|>"
    marker_token_id = tokenizer.convert_tokens_to_ids(MARKER_STRING)
    if marker_token_id == tokenizer.unk_token_id:
        raise ValueError(f"Marker token '{MARKER_STRING}' not found in tokenizer vocabulary!")
    print(f"--- Using marker '{MARKER_STRING}' (ID: {marker_token_id}) for precise overlap removal. ---")
except Exception as e:
    raise RuntimeError(f"FATAL: Could not load tokenizer. Error: {e}")

# --- Load Model ---
try:
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=dtype, trust_remote_code=True).to(device)
    model.eval()
    print("--- Model Loaded Successfully ---")
except Exception as e:
    raise RuntimeError(f"FATAL: Could not load model. Error: {e}")

# --- Helper function for chunking text (Unchanged) ---
def chunk_text(text: str, max_size: int) -> list[str]:
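    """Split text into chunks of at most max_size characters, preferring to break
    after the last '.' inside the current window so sentences stay intact."""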
    if not text:
        return []
    chunks, start_index = [], 0
    while start_index < len(text):
        end_index = start_index + max_size
        if end_index >= len(text):
            chunks.append(text[start_index:])
            break
        split_pos = text.rfind('.', start_index, end_index)
        if split_pos != -1:
            chunk, start_index = text[start_index : split_pos + 1], split_pos + 1
        else:
            chunk, start_index = text[start_index:end_index], end_index
        chunks.append(chunk.strip())
    return [c for c in chunks if c]

# --- Modified translation helper to return IDs ---
def do_translation(text_to_translate: str) -> tuple[str, list[int]]:
"""Runs a single translation and returns both the decoded string and the token IDs."""
if not text_to_translate.strip(): return "", []
messages = [{"role": "user", "content": text_to_translate}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([prompt], add_special_tokens=False, return_tensors="pt").to(device)
generated_ids_tensor = model.generate(**model_inputs, max_new_tokens=2048, do_sample=True, temperature=0.7, top_p=0.95, top_k=50)
input_token_len = model_inputs.input_ids.shape[1]
output_ids = generated_ids_tensor[0][input_token_len:].tolist()
decoded_text = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
return decoded_text, output_ids
# --- Step 3: Core Translation Function (PRECISE TOKEN ID METHOD + DIFF FALLBACK) ---
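# Overlap-removal strategy:
#   Primary: prepend the previous source context plus MARKER_STRING to the chunk,
#   then locate the marker's token ID in the generated IDs and keep only what follows.
#   Fallback: if the marker does not survive generation, translate the context alone
#   and strip the matching word-prefix from the full translation.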
@spaces.GPU
@torch.no_grad()
def translate_with_chunks(input_text: str, chunk_size: int, context_words: int, progress=gr.Progress()) -> str:
"""
Processes chunks using a precise token ID search for overlap removal, with a robust 'diff' fallback.
"""
progress(0, desc="Starting...")
if not input_text or not input_text.strip(): return "Input text is empty. Please enter some text to translate."
text_chunks = chunk_text(input_text, chunk_size) if len(input_text) > chunk_size else [input_text]
num_chunks = len(text_chunks)
print(f"Processing {num_chunks} chunk(s).")
all_results = []
english_context = ""
for i, chunk in enumerate(text_chunks):
progress(0.2 + (i / num_chunks) * 0.7, desc=f"Translating chunk {i+1}/{num_chunks}")
if not english_context:
# First chunk: no context needed
final_translation_for_chunk, _ = do_translation(chunk)
else:
prompt_with_marker = f"{english_context} {MARKER_STRING} {chunk}"
full_translation_str, full_translation_ids = do_translation(prompt_with_marker)
# --- Primary Method: Search for Marker Token ID ---
try:
marker_index = full_translation_ids.index(marker_token_id)
print("Precise marker token ID found. Slicing output.")
clean_ids = full_translation_ids[marker_index + 1:]
final_translation_for_chunk = tokenizer.decode(clean_ids, skip_special_tokens=True).strip()
# --- Fallback Method: 'Diff' Algorithm ---
except ValueError:
print(f"Warning: Marker token ID {marker_token_id} not in output. Falling back to diff algorithm.")
translated_context_str, _ = do_translation(english_context)
context_words_list = translated_context_str.split()
full_translation_words_list = full_translation_str.split()
overlap_len_in_words = 0
for j in range(min(len(context_words_list), len(full_translation_words_list))):
if context_words_list[j].strip('.,!?;:').lower() != full_translation_words_list[j].strip('.,!?;:').lower():
break
overlap_len_in_words += 1
final_translation_for_chunk = " ".join(full_translation_words_list[overlap_len_in_words:])
all_results.append(final_translation_for_chunk)
print(f"Chunk {i+1} processed successfully.")
if context_words > 0:
english_context = " ".join(chunk.split()[-context_words:])
full_output = " ".join(all_results)
progress(1.0, desc="Done!")
return full_output
# --- Step 4: Create and Launch the Gradio App ---
print("\n--- Initializing Gradio Interface ---")
app = gr.Interface(
    fn=translate_with_chunks,
    inputs=[
        gr.Textbox(lines=15, label="Input Text", placeholder="Enter long text to process here..."),
        gr.Slider(minimum=128, maximum=1536, value=1024, step=64, label="Character Chunk Size"),
        gr.Slider(minimum=0, maximum=50, value=15, step=5, label="Context Overlap (Source Words)")
    ],
    outputs=gr.Textbox(lines=15, label="Model Output", interactive=False),
    title="ERNIE 4.5 Context-Aware Translator",
    description="Processes long text using a precise token-based method with a robust 'diff' fallback to ensure high-quality, consistent translations.",
    allow_flagging="never"
)
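
# gr.Progress updates rely on Gradio's request queue, so enable it before launching.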
if __name__ == "__main__":
    app.queue().launch()