# --- START OF FILE app.py ---
import sys
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
from dotenv import load_dotenv

# --- FIX: Add project root to Python's path ---
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, project_root)
# --- Updated Spaces import for Zero-GPU compatibility ---
try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")

    class DummySpaces:
        def GPU(self, *args, **kwargs):
            def decorator(func):
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator

    spaces = DummySpaces()
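# Note: the GPU decorator (real or dummy) is applied to the translation entry
# point further below, so the same file runs both on ZeroGPU hardware and on a
# plain CPU machine without changes.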
# --- Step 1: Hugging Face Authentication ---
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("FATAL: Hugging Face token not found. Please set the HF_TOKEN environment variable.")

print("--- Logging in to Hugging Face Hub ---")
login(token=HF_TOKEN)

# --- Step 2: Initialize Model and Tokenizer (Load Once on Startup) ---
MODEL_NAME = "Gregniuki/ERNIE-4.5-0.3B-PT-Translator-EN-PL-EN"
print(f"--- Loading model from Hugging Face Hub: {MODEL_NAME} ---")

# --- Device Setup (Zero GPU Support) ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU detected. Using CUDA.")
else:
    device = torch.device("cpu")
    print("No GPU detected. Using CPU.")

dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
print(f"--- Using dtype: {dtype} ---")
print(f"--- Loading tokenizer from Hub: {MODEL_NAME} ---") | |
try: | |
tokenizer = AutoTokenizer.from_pretrained( | |
MODEL_NAME, | |
trust_remote_code=True | |
) | |
print("--- Tokenizer Loaded Successfully ---") | |
except Exception as e: | |
raise RuntimeError(f"FATAL: Could not load tokenizer from the Hub. Error: {e}") | |
print(f"--- Loading Model with PyTorch from Hub: {MODEL_NAME} ---") | |
try: | |
model = AutoModelForCausalLM.from_pretrained( | |
MODEL_NAME, | |
torch_dtype=dtype, | |
trust_remote_code=True | |
).to(device) | |
model.eval() | |
print("--- Model Loaded Successfully ---") | |
except Exception as e: | |
raise RuntimeError(f"FATAL: Could not load model from the Hub. Error: {e}") | |
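# Note: trust_remote_code=True allows custom modeling code bundled with the
# model repository to execute locally; only enable it for repositories you trust.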
# --- Helper function for chunking text (Unchanged) ---
def chunk_text(text: str, max_size: int) -> list[str]:
    """Splits text into chunks, trying to break at sentence endings."""
    if not text:
        return []
    chunks, start_index = [], 0
    while start_index < len(text):
        end_index = start_index + max_size
        if end_index >= len(text):
            chunks.append(text[start_index:])
            break
        split_pos = text.rfind('.', start_index, end_index)
        if split_pos != -1:
            chunk, start_index = text[start_index : split_pos + 1], split_pos + 1
        else:
            chunk, start_index = text[start_index:end_index], end_index
        chunks.append(chunk.strip())
    return [c for c in chunks if c]
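# Example (illustration only): chunk_text("Hello. World.", 8) yields
# ["Hello.", " World."] -- the first piece breaks at the period inside the
# 8-character window, and the remainder is returned as the final chunk.
# When no period falls inside the window, the text is hard-cut at max_size.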
# --- Simplified translation helper for internal use ---
def do_translation(text_to_translate: str) -> str:
    """A clean helper function to run a single translation."""
    if not text_to_translate.strip():
        return ""
    messages = [{"role": "user", "content": text_to_translate}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([prompt], add_special_tokens=False, return_tensors="pt").to(device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=2048,
        do_sample=True, temperature=0.7, top_p=0.95, top_k=50
    )
    input_token_len = model_inputs.input_ids.shape[1]
    output_ids = generated_ids[0][input_token_len:].tolist()
    return tokenizer.decode(output_ids, skip_special_tokens=True).strip()
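# Note: the slice generated_ids[0][input_token_len:] keeps only the tokens
# produced after the prompt. Sampling is enabled (do_sample=True), so repeated
# calls on the same input can return slightly different wording.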
# --- Step 3: Core Translation Function (USING ROBUST 'DIFF' ALGORITHM) ---
# On ZeroGPU hardware, @spaces.GPU() requests a GPU for the duration of each
# call; locally, the DummySpaces fallback defined above makes it a no-op.
@spaces.GPU()
def translate_with_chunks(input_text: str, chunk_size: int, context_words: int, progress=gr.Progress()) -> str:
""" | |
Processes text by chunks, using a robust word-by-word 'diff' algorithm | |
to reliably find and remove the overlapping translation. | |
""" | |
progress(0, desc="Starting...") | |
print("--- Inference with robust 'diff' context method started ---") | |
if not input_text or not input_text.strip(): | |
return "Input text is empty. Please enter some text to translate." | |
progress(0.1, desc="Chunking Text...") | |
text_chunks = chunk_text(input_text, chunk_size) if len(input_text) > chunk_size else [input_text] | |
num_chunks = len(text_chunks) | |
print(f"Processing {num_chunks} chunk(s).") | |
all_results = [] | |
english_context = "" | |
for i, chunk in enumerate(text_chunks): | |
progress(0.2 + (i / num_chunks) * 0.7, desc=f"Translating chunk {i+1}/{num_chunks}") | |
prompt_with_context = (english_context + " " + chunk).strip() | |
full_translation = do_translation(prompt_with_context) | |
final_translation_for_chunk = full_translation | |
if english_context: | |
translated_context = do_translation(english_context) | |
# --- Start of the Diff Algorithm --- | |
context_words_list = translated_context.split() | |
full_translation_words_list = full_translation.split() | |
# Find the first point of difference | |
overlap_len_in_words = 0 | |
for i in range(min(len(context_words_list), len(full_translation_words_list))): | |
# Compare words robustly (lowercase, strip punctuation) | |
if context_words_list[i].strip('.,!?;:').lower() != full_translation_words_list[i].strip('.,!?;:').lower(): | |
break | |
overlap_len_in_words += 1 | |
# The new text starts after the matching words | |
final_translation_for_chunk = " ".join(full_translation_words_list[overlap_len_in_words:]) | |
# --- End of the Diff Algorithm --- | |
all_results.append(final_translation_for_chunk) | |
print(f"Chunk {i+1} processed successfully.") | |
if context_words > 0: | |
words = chunk.split() | |
english_context = " ".join(words[-context_words:]) | |
progress(0.95, desc="Reassembling Results...") | |
full_output = " ".join(all_results) | |
progress(1.0, desc="Done!") | |
return full_output | |
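# Illustration of the diff step (hypothetical values): if the carried-over
# English context "Good morning." translates to "Dzień dobry." and the full
# chunk translation comes back as "Dzień dobry. Jak się masz?", the
# word-by-word comparison skips the two overlapping words and keeps only
# "Jak się masz?" for this chunk, avoiding duplicated text in the output.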
# --- Step 4: Create and Launch the Gradio App ---
print("\n--- Initializing Gradio Interface ---")
app = gr.Interface(
    fn=translate_with_chunks,
    inputs=[
        gr.Textbox(lines=15, label="Input Text", placeholder="Enter long text to process here..."),
        gr.Slider(minimum=128, maximum=1536, value=1024, step=64, label="Character Chunk Size"),
        gr.Slider(
            minimum=0,
            maximum=50,
            value=15,
            step=5,
            label="Context Overlap (Source Words)",
            info="Number of English words from the end of the previous chunk to provide as context for the next one. Ensures consistency."
        )
    ],
    outputs=gr.Textbox(lines=15, label="Model Output", interactive=False),
    title="ERNIE 4.5 Context-Aware Translator",
    description="Processes long text using a robust 'diff' algorithm to ensure high-quality, consistent translations without duplication.",
    allow_flagging="never"
)

if __name__ == "__main__":
    app.queue().launch()