# app.py — TimeLens: Sanskrit OCR + cleanup + translation (Hugging Face Space)
import gradio as gr
import cv2
import numpy as np
import easyocr
from PIL import Image
import pillow_avif
import requests
import os
import logging
import re
import torch
import time
from functools import lru_cache
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from deep_translator import GoogleTranslator
from IndicTransToolkit.processor import IndicProcessor
# -------------------- ENV + LOGGING --------------------
# Secrets/config are injected as environment variables (e.g. HF Spaces
# settings); any of these may be None when not configured.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
MISTRAL_AGENT_ID = os.getenv("MISTRAL_AGENT_ID")
HF_TOKEN = os.getenv("HF_TOKEN")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Cap torch intra-op parallelism — this app runs CPU-only inference.
torch.set_num_threads(4)
# -------------------- UTILITIES --------------------
def preprocess_ocr_text(text: str) -> str:
    """Keep only Devanagari characters, whitespace and danda punctuation,
    then collapse runs of whitespace into single spaces."""
    devanagari_only = re.sub(r"[^\u0900-\u097F\s।॥]", "", text)
    collapsed = re.sub(r"\s+", " ", devanagari_only)
    return collapsed.strip()
def sanitize_for_processor(text: str) -> str:
    """Remove angle brackets (reserved tokens) and any trailing danda marks."""
    without_brackets = re.sub(r"[<>]", "", text)
    return re.sub(r"[।॥]+\s*$", "", without_brackets).strip()
def split_sanskrit_verses(text: str) -> list:
    """Split Devanagari text into verses, keeping the terminating danda
    (। or ॥) attached to its verse; a trailing fragment without a danda is
    kept as the final verse."""
    verses = []
    buffer = ""
    # re.split with a capturing group keeps the delimiters in the stream.
    for piece in re.split(r'([।॥])', text.strip()):
        if piece == "।" or piece == "॥":
            verses.append((buffer + piece).strip())
            buffer = ""
        else:
            buffer += piece
    if buffer.strip():
        verses.append(buffer.strip())
    return [v for v in verses if v]
def call_mistral_cleaner(noisy_text: str) -> str:
    """Send noisy Sanskrit OCR text to the Mistral agent for cleanup.

    Parameters
    ----------
    noisy_text : str
        Pre-filtered OCR output in Devanagari script.

    Returns
    -------
    str
        The cleaned text, or an ``"Error: ..."`` string on failure — the UI
        displays this directly instead of crashing.
    """
    instructions = """You are an AI agent specialized in cleaning Sanskrit OCR text..."""
    headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}"}
    payload = {
        "agent_id": MISTRAL_AGENT_ID,
        "messages": [
            {"role": "system", "content": instructions},
            {"role": "user", "content": f"Clean this noisy OCR Sanskrit text:\n{noisy_text}"}
        ]
    }
    try:
        response = requests.post(
            "https://api.mistral.ai/v1/agents/completions",
            json=payload,
            headers=headers,
            proxies={"http": "", "https": ""},  # bypass any ambient proxy settings
            # BUG FIX: requests has no default timeout — a stalled API call
            # previously hung the request (and the UI) forever.
            timeout=60,
        )
        response.raise_for_status()
        cleaned = response.json()["choices"][0]["message"]["content"]
    except (requests.RequestException, KeyError, IndexError, ValueError) as e:
        # Narrowed from a broad `except Exception`; log so failures are
        # diagnosable instead of vanishing into the returned string only.
        logger.error("Mistral cleanup failed: %s", e)
        return f"Error: {str(e)}"
    return cleaned.strip()
# -------------------- CPU-ONLY MODEL LOADERS --------------------
@lru_cache(maxsize=1)
def load_easyocr():
    """Build (once per process) a CPU-only EasyOCR reader for
    Devanagari-script languages (Hindi, Marathi, Nepali)."""
    reader = easyocr.Reader(["hi", "mr", "ne"], gpu=False)
    return reader
@lru_cache(maxsize=1)
def load_indic_model():
    """Load AI4Bharat IndicTrans2 (indic→indic, 1B) pinned to CPU, together
    with its tokenizer, an IndicProcessor, and a Google→English translator.

    Cached with lru_cache so the heavy model download/load happens only once
    per process. Returns (tokenizer, model, processor, translator, device).
    """
    device = "cpu"  # force pure-CPU inference
    checkpoint = "ai4bharat/indictrans2-indic-indic-1B"
    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        token=HF_TOKEN,
        trust_remote_code=True,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        checkpoint,
        token=HF_TOKEN,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float32,  # float32 only — CPU support
    ).to(device)
    processor = IndicProcessor(inference=True)
    english_translator = GoogleTranslator(source="auto", target="en")
    return tokenizer, model, processor, english_translator, device
# -------------------- OCR STEP --------------------
def run_ocr(img):
    """Run EasyOCR over a PIL image and return the recognized text.

    Converts to grayscale first; returns a placeholder message when no
    image was uploaded.
    """
    if img is None:
        return "No image uploaded."
    grayscale = np.array(img.convert("L"))
    detections = load_easyocr().readtext(grayscale, detail=1, paragraph=True)
    return " ".join(item[1] for item in detections)
# -------------------- CLEANING STEP --------------------
def clean_sanskrit(text):
    """Pre-filter OCR output to Devanagari-only text, then ask the Mistral
    agent to clean it. Returns a placeholder message for blank input."""
    if not text.strip():
        return "No text found."
    return call_mistral_cleaner(preprocess_ocr_text(text))
# -------------------- TRANSLATION STEP --------------------
# IndicTrans2 language codes (ISO 639-3 + script tag) for the target languages.
TARGET_LANGS = ["hin_Deva", "kan_Knda", "tam_Taml", "tel_Telu"]
# Human-readable names keyed by the codes above; used as keys of the
# translation output dict shown in the UI.
LANG_NAMES = {
    "hin_Deva": "Hindi",
    "kan_Knda": "Kannada",
    "tam_Taml": "Tamil",
    "tel_Telu": "Telugu"
}
def translate(cleaned_text):
    """Translate cleaned Sanskrit text verse-by-verse into several Indic
    languages, plus a best-effort English rendering of each.

    Parameters
    ----------
    cleaned_text : str
        Devanagari text (typically the output of `clean_sanskrit`).

    Returns
    -------
    dict
        ``{language name: {"indic": translated text, "english": text}}``.
        "english" is "" when Google translation fails; the dict is empty
        when no verses are found in the input.
    """
    tokenizer, model, ip, translator, device = load_indic_model()
    verses = split_sanskrit_verses(sanitize_for_processor(cleaned_text))
    # ROBUSTNESS: empty input previously still looped over every target
    # language and called the Google translator on "".
    if not verses:
        return {}
    output = {}
    for tgt in TARGET_LANGS:
        per_verse = []
        for verse in verses:
            # Preprocess a single verse for the given language pair.
            batch = ip.preprocess_batch([verse], src_lang="san_Deva", tgt_lang=tgt)
            inputs = tokenizer(
                batch,
                return_tensors="pt",
                padding="longest",
                truncation=True
            ).to(device)
            # CPU-friendly generation settings.
            with torch.no_grad():
                generated = model.generate(
                    **inputs,
                    max_new_tokens=512,  # bounded output keeps CPU latency manageable
                    num_beams=3,         # balanced quality/speed
                    early_stopping=True,
                    do_sample=False      # deterministic output
                )
            decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
            per_verse.append(ip.postprocess_batch(decoded, lang=tgt)[0])
        full = "\n".join(per_verse)
        try:
            english = translator.translate(full)
        except Exception as e:
            # BUG FIX: was a bare `except:` that also swallowed
            # KeyboardInterrupt/SystemExit. Keep the best-effort semantics
            # (English is optional) but log the failure.
            logger.warning("English translation failed for %s: %s", tgt, e)
            english = ""
        output[LANG_NAMES[tgt]] = {"indic": full, "english": english}
    return output
# -------------------- UI (GRADIO BLOCKS) --------------------
# Three-stage pipeline UI: OCR -> Mistral cleanup -> IndicTrans2 translation.
# NOTE(review): original indentation was lost in extraction; this is the
# conventional reconstruction (buttons at Blocks level, below each row).
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 📖 TimeLens - Sanskrit OCR + Cleanup + Translation (CPU Version)")
    # Stage 1: image in, raw OCR text out.
    with gr.Row():
        img_in = gr.Image(type="pil", label="Upload Manuscript Image")
        extracted_box = gr.Textbox(label="Extracted OCR Text", lines=8)
    ocr_btn = gr.Button("🔍 Extract OCR")
    # Stage 2: cleanup via the Mistral agent.
    with gr.Row():
        cleaned_box = gr.Textbox(label="Cleaned Sanskrit Text", lines=8)
    clean_btn = gr.Button("✨ Clean Sanskrit (Mistral)")
    # Stage 3: translation to Indic languages + English.
    with gr.Row():
        trans_output = gr.JSON(label="Translations Output")
    trans_btn = gr.Button("🌐 Translate to Indic Languages + English")
    # Bind events — each button drives one pipeline stage.
    ocr_btn.click(run_ocr, inputs=img_in, outputs=extracted_box)
    clean_btn.click(clean_sanskrit, inputs=extracted_box, outputs=cleaned_box)
    trans_btn.click(translate, inputs=cleaned_box, outputs=trans_output)
demo.launch()