# --- START OF FILE app.py ---
import sys
import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
from dotenv import load_dotenv
# --- FIX: Add project root to Python's path ---
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, project_root)
# --- Updated Spaces import for Zero-GPU compatibility ---
try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")
    class DummySpaces:
        def GPU(self, *args, **kwargs):
            def decorator(func):
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator
    spaces = DummySpaces()
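# Note: on Hugging Face Zero-GPU Spaces, @spaces.GPU requests a GPU for the duration of
# the decorated call; the dummy above simply preserves that interface when running locally.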
# --- Step 1: Hugging Face Authentication ---
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
raise ValueError("FATAL: Hugging Face token not found. Please set the HF_TOKEN environment variable.")
print("--- Logging in to Hugging Face Hub ---")
login(token=HF_TOKEN)
# --- Step 2: Initialize Model and Tokenizer (Load Once on Startup) ---
MODEL_NAME = "Gregniuki/ERNIE-4.5-0.3B-PT-Translator-EN-PL-EN"
print(f"--- Loading model from Hugging Face Hub: {MODEL_NAME} ---")
# --- Device Setup (Zero GPU Support) ---
if torch.cuda.is_available():
device = torch.device("cuda")
print("GPU detected. Using CUDA.")
else:
device = torch.device("cpu")
print("No GPU detected. Using CPU.")
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
print(f"--- Using dtype: {dtype} ---")
print(f"--- Loading tokenizer from Hub: {MODEL_NAME} ---")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True
    )
    # Define the special token from the model's tokenizer
    SPECIAL_MARKER = "<|LOC_0|>"
    print(f"--- Using special marker for overlap: {SPECIAL_MARKER} ---")
except Exception as e:
    raise RuntimeError(f"FATAL: Could not load tokenizer from the Hub. Error: {e}")
print(f"--- Loading Model with PyTorch from Hub: {MODEL_NAME} ---")
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        trust_remote_code=True
    ).to(device)
    model.eval()
    print("--- Model Loaded Successfully ---")
except Exception as e:
    raise RuntimeError(f"FATAL: Could not load model from the Hub. Error: {e}")
# --- Helper function for chunking text (Unchanged) ---
def chunk_text(text: str, max_size: int) -> list[str]:
"""Splits text into chunks, trying to break at sentence endings."""
if not text: return []
chunks, start_index = [], 0
while start_index < len(text):
end_index = start_index + max_size
if end_index >= len(text):
chunks.append(text[start_index:])
break
split_pos = text.rfind('.', start_index, end_index)
if split_pos != -1:
chunk, start_index = text[start_index : split_pos + 1], split_pos + 1
else:
chunk, start_index = text[start_index:end_index], end_index
chunks.append(chunk.strip())
return [c for c in chunks if c]
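# Illustrative example (assuming the strip-on-final-chunk behavior above):
#   chunk_text("First sentence. Second sentence.", 20)
#   -> ["First sentence.", "Second sentence."]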
# --- Simplified translation helper for internal use ---
def do_translation(text_to_translate: str) -> str:
"""A clean helper function to run a single translation."""
if not text_to_translate.strip():
return ""
messages = [{"role": "user", "content": text_to_translate}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([prompt], add_special_tokens=False, return_tensors="pt").to(device)
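    # Sampling-based decoding (temperature/top-p/top-k); do_sample=False (greedy decoding)
    # could be used instead if reproducible output is preferred.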
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=2048,
        do_sample=True, temperature=0.7, top_p=0.95, top_k=50
    )
    input_token_len = model_inputs.input_ids.shape[1]
    output_ids = generated_ids[0][input_token_len:].tolist()
    # We use skip_special_tokens=False to ensure our marker isn't accidentally removed by the decoder
    return tokenizer.decode(output_ids, skip_special_tokens=False).strip()
# --- Step 3: Core Translation Function (USING SPECIAL TOKEN MARKER) ---
@spaces.GPU
@torch.no_grad()
def translate_with_chunks(input_text: str, chunk_size: int, context_words: int, progress=gr.Progress()) -> str:
"""
Processes text by chunks, using a special token to mark the overlap
for clean and reliable removal.
"""
progress(0, desc="Starting...")
print("--- Inference with special token context method started ---")
if not input_text or not input_text.strip():
return "Input text is empty. Please enter some text to translate."
progress(0.1, desc="Chunking Text...")
text_chunks = chunk_text(input_text, chunk_size) if len(input_text) > chunk_size else [input_text]
num_chunks = len(text_chunks)
print(f"Processing {num_chunks} chunk(s).")
all_results = []
english_context = ""
for i, chunk in enumerate(text_chunks):
progress(0.2 + (i / num_chunks) * 0.7, desc=f"Translating chunk {i+1}/{num_chunks}")
final_translation_for_chunk = ""
if english_context:
# Construct a prompt with the context, the special marker, and the new chunk
prompt_with_marker = f"{english_context} {SPECIAL_MARKER} {chunk}"
full_translation = do_translation(prompt_with_marker)
# Find the marker in the translated output
marker_position = full_translation.find(SPECIAL_MARKER)
if marker_position != -1:
# If marker is found, the clean translation is everything after it
print("Special marker found in output. Removing overlap.")
start_of_clean_text = marker_position + len(SPECIAL_MARKER)
final_translation_for_chunk = full_translation[start_of_clean_text:].lstrip()
else:
# If marker is not found, the model likely translated or ignored it.
# We fall back to showing the full translation and the user will see the overlap.
print(f"Warning: Marker '{SPECIAL_MARKER}' not found in translation. Overlap may remain.")
final_translation_for_chunk = full_translation
else:
# For the first chunk, there's no context, so just translate it directly
final_translation_for_chunk = do_translation(chunk)
all_results.append(final_translation_for_chunk)
print(f"Chunk {i+1} processed successfully.")
        if context_words > 0:
            words = chunk.split()
            english_context = " ".join(words[-context_words:])
    progress(0.95, desc="Reassembling Results...")
    # We must clean up any leftover special tokens from the final joined output
    full_output = " ".join(all_results).replace(SPECIAL_MARKER, "")
    progress(1.0, desc="Done!")
    return full_output
# --- Step 4: Create and Launch the Gradio App ---
print("\n--- Initializing Gradio Interface ---")
app = gr.Interface(
    fn=translate_with_chunks,
    inputs=[
        gr.Textbox(lines=15, label="Input Text", placeholder="Enter long text to process here..."),
        gr.Slider(minimum=128, maximum=1536, value=1024, step=64, label="Character Chunk Size"),
        gr.Slider(
            minimum=0,
            maximum=50,
            value=15,
            step=5,
            label="Context Overlap (Source Words)",
            info="Number of source-language words from the end of the previous chunk to provide as context for the next one. Ensures consistency."
        )
    ],
    outputs=gr.Textbox(lines=15, label="Model Output", interactive=False),
    title="ERNIE 4.5 Context-Aware Translator",
    description="Processes long text using a special token marker to ensure high-quality, consistent translations without duplication.",
    allow_flagging="never"
)
if __name__ == "__main__":
    app.queue().launch()
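# To run locally: `python app.py`, with HF_TOKEN set in the environment or in a .env file.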