Spaces:
Runtime error
Runtime error
File size: 14,239 Bytes
41568e7 84a3cac 41568e7 d462d74 84a3cac 41568e7 5b88c25 95d6a14 41568e7 95d6a14 84a3cac 95d6a14 7b25576 95d6a14 0e9d987 95d6a14 7b25576 95d6a14 84a3cac 894527f 84a3cac 894527f 84a3cac 95d6a14 894527f 95d6a14 894527f 84a3cac 894527f 95d6a14 894527f 95d6a14 894527f 84a3cac 894527f 84a3cac 894527f 84a3cac 894527f 84a3cac 894527f 84a3cac 894527f 84a3cac 894527f 84a3cac 894527f 84a3cac 95d6a14 84a3cac 95d6a14 894527f 95d6a14 84a3cac 95d6a14 894527f 95d6a14 7b25576 41568e7 95d6a14 41568e7 95d6a14 41568e7 95d6a14 41568e7 84a3cac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 |
# app.py
import streamlit as st
import cv2
import numpy as np
import easyocr
from PIL import Image
import pillow_avif # enables AVIF support for Pillow
import requests
import os
import logging
import re
import torch
import time
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from deep_translator import GoogleTranslator
from IndicTransToolkit.processor import IndicProcessor # ← use the official pre/post processor
# -------------------- ENV + LOGGING --------------------
# All three credentials are read from Streamlit secrets; the app cannot run
# without any one of them, so the script stops right after reporting the problem.
MISTRAL_API_KEY = st.secrets.get("MISTRAL_API_KEY")
MISTRAL_AGENT_ID = st.secrets.get("MISTRAL_AGENT_ID")
HF_TOKEN = st.secrets.get("HF_TOKEN")
if not MISTRAL_API_KEY or not MISTRAL_AGENT_ID or not HF_TOKEN:
    st.error("❌ Missing required keys in Streamlit secrets. Please set HF_TOKEN, MISTRAL_API_KEY, and MISTRAL_AGENT_ID.")
    st.stop()

MISTRAL_URL = "https://api.mistral.ai/v1/agents/completions"

# Exempt the Mistral host from proxying, but keep any NO_PROXY entries the
# environment already defines instead of overwriting them.
_no_proxy_hosts = [
    host.strip()
    for host in os.environ.get("NO_PROXY", "").split(",")
    if host.strip()
]
if "api.mistral.ai" not in _no_proxy_hosts:
    _no_proxy_hosts.append("api.mistral.ai")
os.environ["NO_PROXY"] = ",".join(_no_proxy_hosts)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# (Optional) keep CPU threads modest on Spaces
try:
    torch.set_num_threads(4)
except Exception as exc:
    # Best effort only: a failure here must not stop the app, but it should
    # be visible in the logs rather than silently ignored.
    logger.warning("Could not limit torch CPU threads: %s", exc)
# -------------------- STREAMLIT CONFIG --------------------
st.set_page_config(
    page_title="OCR + Sanskrit Cleaner & Translator AI",
    layout="wide",
)
st.title("📖 OCR for Devanagari - Sanskrit Manuscripts + AI Cleaner + Multi-Language Translation")
st.write(
    "Upload a Sanskrit manuscript → OCR → Mistral AI cleans it → "
    "Translates into Indic languages using AI4Bharat IndicTrans2 (single model). English via fallback."
)

# Target language tag -> display name. Insertion order is the order in which
# translations are produced and shown.
LANG_NAMES = {
    "hin_Deva": "Hindi",
    "kan_Knda": "Kannada",
    "tam_Taml": "Tamil",
    "tel_Telu": "Telugu",
}
# Hindi, Kannada, Tamil, Telugu — derived from the mapping so the two cannot drift apart.
TARGET_LANGS = list(LANG_NAMES)
# -------------------- UTILITIES --------------------
def preprocess_ocr_text(text: str) -> str:
    """Reduce raw OCR output to Devanagari characters, whitespace and dandas.

    Every character that is neither in the Devanagari block (U+0900-U+097F),
    nor whitespace, nor a danda (। / ॥) is deleted. Whitespace runs are then
    collapsed to a single space and the result is trimmed at both ends.
    """
    devanagari_only = re.sub(r"[^\u0900-\u097F\s।॥]", "", text)
    return re.sub(r"\s+", " ", devanagari_only).strip()
def sanitize_for_processor(text: str) -> str:
    """Prepare text content for the translation pre-processor.

    Drops every '<' and '>' character, then removes a trailing run of dandas
    (। / ॥) together with any whitespace after it, and finally trims the
    surrounding whitespace.
    """
    without_brackets = "".join(ch for ch in text if ch not in ("<", ">"))
    without_trailing_dandas = re.sub(r"[।॥]+\s*$", "", without_brackets)
    return without_trailing_dandas.strip()
def split_sanskrit_verses(text: str) -> list:
    """Split Sanskrit text into verses/sentences using । and ॥ as delimiters.

    Each verse keeps the delimiter that closes it. A delimiter that has no
    words in front of it never becomes a verse of its own:

    * a repeated closing mark (e.g. ``॥॥``) is appended to the previous verse;
    * a mark at the very start of the text is kept as a prefix of the first
      verse.

    Args:
        text: The text to split. Surrounding whitespace is ignored.

    Returns:
        A list of non-empty, stripped verse strings, in source order. Text
        after the last delimiter is returned as a final verse. Input made
        only of delimiters is returned as a single item.
    """
    delimiters = ("।", "॥")

    def has_words(chunk: str) -> bool:
        # True when the chunk holds anything besides whitespace and dandas.
        return any(not ch.isspace() and ch not in delimiters for ch in chunk)

    verses = []
    pending = ""
    # The capture group makes re.split return the delimiters as items too.
    for piece in re.split(r"([।॥])", text.strip()):
        if piece not in delimiters:
            pending += piece
            continue
        if has_words(pending):
            verses.append((pending + piece).strip())
            pending = ""
        elif verses:
            # Extra closing mark: belongs to the verse that was just closed.
            verses[-1] += piece
            pending = ""
        else:
            # Leading mark: carry it forward into the first real verse.
            pending += piece
    if pending.strip():  # trailing text without a closing delimiter
        verses.append(pending.strip())
    return verses
# System prompt sent with every cleaning request.
_MISTRAL_CLEANER_INSTRUCTIONS = """You are an AI agent specialized in cleaning, correcting, and restoring Sanskrit text extracted from OCR (Optical Character Recognition) outputs.
Your job is to transform noisy, imperfect, or partially garbled Sanskrit text into a clean, readable, and grammatically correct Sanskrit version written only in Devanagari script.
OBJECTIVE
Correct OCR-induced spelling errors, misrecognized characters, and misplaced diacritics.
Preserve the original Sanskrit meaning and structure.
Maintain the Devanagari script output—never use transliteration or translation.
Output only the corrected Sanskrit text, with no explanations or extra commentary.
RULES
Do not translate Sanskrit text into any other language.
Do not add new words.
Fix errors like:
- Missing or extra characters
- Wrong vowel marks
- Garbled words
- Latin characters
- Bad spacing
Keep Sanskrit grammar intact.
Preserve punctuation symbols like । and ॥.
OUTPUT FORMAT
Only output the cleaned Sanskrit text.
No explanation. No formatting. No English.
Sample Input:
"र) ।श्रीगणेगा यनमध। ।भथरात्रिसृत्तं ) )) पीवष्ियाधित पर्प्री व्याक्षनिनविखव ) )"
Sample Output:
"। श्रीगणेशाय नमः । भद्ररात्रिस्मृतं । पीयूषाधित प्रप्री व्याख्यानविख्यातम् ।" """


def _retry_delay_seconds(header_value, default=60):
    """Turn a Retry-After header value into a non-negative number of seconds."""
    try:
        return max(0, int(header_value))
    except (TypeError, ValueError):
        # Header missing, or not a plain integer (e.g. an HTTP date).
        return default


def call_mistral_cleaner(noisy_text: str, max_retries=3, *, timeout=60) -> str:
    """Clean OCR Sanskrit text via your Mistral Agent.

    Args:
        noisy_text: OCR text to be corrected.
        max_retries: Maximum number of requests to send. Only HTTP 429
            (rate limit) responses trigger another attempt.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled connection cannot block the app indefinitely.

    Returns:
        The cleaned text on success. On any failure a string starting with
        "Error" is returned instead of raising; callers detect failure by
        that prefix.
    """
    for attempt in range(max_retries):
        try:
            headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
            payload = {
                "agent_id": MISTRAL_AGENT_ID,
                "messages": [
                    {"role": "system", "content": _MISTRAL_CLEANER_INSTRUCTIONS},
                    {"role": "user", "content": f"Clean this noisy OCR Sanskrit text:\n{noisy_text}"},
                ],
            }
            response = requests.post(
                MISTRAL_URL,
                headers=headers,
                json=payload,
                proxies={"http": "", "https": ""},
                timeout=timeout,
            )
            if response.status_code == 429:
                if attempt + 1 >= max_retries:
                    # No attempt left: sleeping here would only delay the error.
                    break
                retry_after = _retry_delay_seconds(response.headers.get("Retry-After"))
                st.warning(f"⏳ Rate limit hit. Retrying in {retry_after}s... (Attempt {attempt + 1}/{max_retries})")
                time.sleep(retry_after)
                continue
            response.raise_for_status()
            result = response.json()
            # Tolerate a missing or empty "choices"/"message" instead of
            # failing with an IndexError.
            choices = result.get("choices") or [{}]
            message = choices[0].get("message") or {}
            cleaned_text = message.get("content") or ""
            return cleaned_text.strip() if cleaned_text else "Error: No output from Agent."
        except Exception as e:
            # Deliberately broad: this is the UI boundary, and every failure
            # is reported to the caller as an "Error..." string.
            logger.error("Error calling Mistral Agent: %s", e)
            return f"Error: {str(e)}"
    return "Error: Max retries exceeded."
# -------------------- CACHED LOADERS --------------------
@st.cache_resource
def get_easyocr_reader():
    """Create the EasyOCR reader used by the app (cached via st.cache_resource).

    The reader is built for the 'hi', 'mr' and 'ne' language codes with GPU
    use turned off.
    """
    language_codes = ['hi', 'mr', 'ne']
    reader = easyocr.Reader(language_codes, gpu=False)
    return reader
@st.cache_resource
def load_indic_model_and_tools():
    """
    Load the Indic→Indic IndicTrans2 checkpoint plus its helper objects.

    Only the single Indic→Indic model is loaded; English output is left to
    the lightweight GoogleTranslator fallback created here.

    Returns:
        A 5-tuple ``(tokenizer, model, indic_processor, english_translator,
        device)`` where ``device`` is "cuda" when CUDA is available and
        "cpu" otherwise.
    """
    cuda_available = torch.cuda.is_available()
    DEVICE = "cuda" if cuda_available else "cpu"
    st.info("💪 Loading AI4Bharat IndicTrans2 (Indic→Indic) with IndicProcessor...")

    checkpoint = "ai4bharat/indictrans2-indic-indic-1B"
    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        token=HF_TOKEN,
        trust_remote_code=True,
    )
    # Half precision only when running on GPU; full precision on CPU.
    weights_dtype = torch.float16 if cuda_available else torch.float32
    model = AutoModelForSeq2SeqLM.from_pretrained(
        checkpoint,
        token=HF_TOKEN,
        trust_remote_code=True,
        torch_dtype=weights_dtype,
        low_cpu_mem_usage=True,
    )
    model = model.to(DEVICE)

    # Official pre/post processor for tags & normalization
    indic_processor = IndicProcessor(inference=True)
    # English fallback translator (very light)
    english_translator = GoogleTranslator(source="auto", target="en")

    st.success(f"✅ IndicTrans2 (Indic→Indic) loaded on {DEVICE.upper()}.")
    return tokenizer, model, indic_processor, english_translator, DEVICE
# -------------------- TRANSLATION --------------------
def translate_sanskrit_indic_only(cleaned_sanskrit, tokenizer_indic, model_indic, ip, translator, DEVICE):
    """
    Translate Sanskrit → {Hindi, Kannada, Tamil, Telugu} using ONLY the Indic→Indic model.
    English is produced by translating the Indic output via deep-translator (lightweight).

    The text is sanitized, split into verses, and each verse is translated on
    its own for every language in TARGET_LANGS. Progress is reported through
    Streamlit widgets.

    Args:
        cleaned_sanskrit: Sanskrit text to translate.
        tokenizer_indic: Tokenizer matching ``model_indic``.
        model_indic: The seq2seq translation model.
        ip: Pre/post processor providing ``preprocess_batch`` and
            ``postprocess_batch``.
        translator: Object with a ``translate(text)`` method used for English.
        DEVICE: Device the tokenized inputs are moved to.

    Returns:
        Dict keyed by target language tag; each value is a dict with the keys
        "indic" (verse translations joined by newlines), "english" (empty
        string if the English fallback failed) and "lang_name".

    Raises:
        Exception: Any error from the Indic translation itself is shown with
            ``st.error`` and then re-raised.
    """
    # Generation settings do not change between verses, so build them once.
    generation_kwargs = {
        "max_new_tokens": 2048,   # upper bound on output length (tune if needed)
        "num_beams": 4,           # beam search
        "early_stopping": True,
        # The KV cache is left disabled, as in the original code.
        # NOTE(review): the previous comment claimed this enabled caching for
        # speed; confirm whether use_cache=True works with this model.
        "use_cache": False,
        "do_sample": False,       # deterministic decoding
        "length_penalty": 1.0,
    }
    try:
        src_lang = "san_Deva"
        input_text = sanitize_for_processor(cleaned_sanskrit)
        # Split into verses for better handling
        input_verses = split_sanskrit_verses(input_text)
        st.info(f"📝 Split into {len(input_verses)} verses for accurate translation.")

        translations_dict = {}
        progress_bar = st.progress(0)
        status_text = st.empty()
        total_steps = len(TARGET_LANGS)

        for i, tgt_lang in enumerate(TARGET_LANGS):
            status_text.text(f"Translating Sanskrit → {LANG_NAMES[tgt_lang]}...")
            tgt_translations = []  # Collect per-verse translations
            for verse in input_verses:
                # One verse per batch
                batch = ip.preprocess_batch([verse], src_lang=src_lang, tgt_lang=tgt_lang)
                inputs = tokenizer_indic(
                    batch, truncation=True, padding="longest", return_tensors="pt"
                ).to(DEVICE)
                with torch.no_grad():
                    generated = model_indic.generate(**inputs, **generation_kwargs)
                decoded = tokenizer_indic.batch_decode(generated, skip_special_tokens=True)
                trans_indic_list = ip.postprocess_batch(decoded, lang=tgt_lang)
                verse_trans = trans_indic_list[0].strip() if trans_indic_list else ""
                tgt_translations.append(verse_trans)

            # Join verses with newlines for readability
            full_trans_indic = "\n".join(tgt_translations)

            # English via lightweight translator from the full Indic output;
            # falls back to the Sanskrit input when there is no Indic output.
            try:
                english_trans = translator.translate(full_trans_indic) if full_trans_indic else translator.translate(input_text)
            except Exception as exc:
                # Best effort: keep the Indic result, but record why the
                # English text is missing instead of failing silently.
                logger.warning(
                    "English fallback translation failed for %s: %s", tgt_lang, exc
                )
                english_trans = ""

            translations_dict[tgt_lang] = {
                "indic": full_trans_indic,
                "english": english_trans,
                "lang_name": LANG_NAMES[tgt_lang],
            }
            progress_bar.progress((i + 1) / total_steps)

        status_text.text("Translation complete!")
        return translations_dict
    except Exception as e:
        st.error(f"❌ Translation failed: {e}")
        raise
# -------------------- MAIN APP --------------------
uploaded_file = st.file_uploader("Upload a Sanskrit manuscript image", type=["png", "jpg", "jpeg", "avif"])

# Results are kept in session state so they survive the rerun that every
# button click triggers.
if "cleaned_sanskrit" not in st.session_state:
    st.session_state.cleaned_sanskrit = ""
if "translations" not in st.session_state:
    st.session_state.translations = None

if uploaded_file:
    # Results belong to one specific image: when a different file is
    # uploaded, discard the cleaned text and translations of the previous one
    # so they are not shown next to the new manuscript.
    upload_signature = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("upload_signature") != upload_signature:
        st.session_state.upload_signature = upload_signature
        st.session_state.cleaned_sanskrit = ""
        st.session_state.translations = None

    pil_img = Image.open(uploaded_file)
    image = np.array(pil_img.convert("L"))  # grayscale array

    # Preview image: Otsu threshold on the inverted grayscale gives a mask;
    # pixels inside the mask become 255, everything else 0.
    inverted = cv2.bitwise_not(image)
    _, mask = cv2.threshold(inverted, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    white_bg = np.ones_like(image) * 255
    final_text_only = cv2.bitwise_and(white_bg, white_bg, mask=mask)

    col1, col2 = st.columns(2)
    with col1:
        st.image(pil_img, caption="📷 Original Image", use_container_width=True)
    with col2:
        st.image(Image.fromarray(final_text_only), caption="🧾 Processed Text-Only Image", use_container_width=True)

    st.subheader("🔍 Extracted OCR Text")
    with st.spinner("Initializing EasyOCR..."):
        try:
            reader = get_easyocr_reader()
        except Exception as e:
            st.error(f"❌ EasyOCR initialization failed: {e}")
            st.stop()
        # OCR runs on the grayscale image, not on the thresholded preview.
        results = reader.readtext(image, detail=1, paragraph=True)
        extracted_text = " ".join([res[1] for res in results])

    if extracted_text.strip():
        st.success("✅ OCR Extraction Successful!")
        st.text_area("Extracted Text", extracted_text, height=200)
        noisy_text = preprocess_ocr_text(extracted_text)

        if st.button("✨ Clean OCR Text with Mistral AI Agent"):
            with st.spinner("Cleaning Sanskrit text using Mistral Agent..."):
                cleaned_sanskrit = call_mistral_cleaner(noisy_text)
            # The cleaner reports failures as a string starting with "Error".
            if cleaned_sanskrit.startswith("Error"):
                st.error(cleaned_sanskrit)
            else:
                st.session_state.cleaned_sanskrit = cleaned_sanskrit
                # New cleaned text invalidates any earlier translations.
                st.session_state.translations = None

        if st.session_state.cleaned_sanskrit:
            st.subheader("📜 Cleaned Sanskrit Text")
            st.text_area("Cleaned Text", st.session_state.cleaned_sanskrit, height=200)

            if st.button("🌐 Translate to Indic Languages + English (1 model)"):
                st.warning("⏳ On CPU, first run may take a few minutes while the model loads (cached after).")
                with st.spinner("Loading IndicTrans2 (Indic→Indic) and translating..."):
                    try:
                        tokenizer_indic, model_indic, ip, translator, DEVICE = load_indic_model_and_tools()
                        translations = translate_sanskrit_indic_only(
                            st.session_state.cleaned_sanskrit,
                            tokenizer_indic, model_indic, ip, translator, DEVICE
                        )
                        st.session_state.translations = translations
                    except Exception as e:
                        st.exception(e)

        if st.session_state.translations:
            st.subheader("🌍 Translations")
            for trans_dict in st.session_state.translations.values():
                st.write(f"--- **{trans_dict['lang_name']}** ---")
                st.write(f"**Sanskrit:** {st.session_state.cleaned_sanskrit}")
                st.write(f"**{trans_dict['lang_name']}:** {trans_dict['indic']}")
                st.write(f"**English (from {trans_dict['lang_name']}):** {trans_dict['english']}")
                st.write("---")
    else:
        st.warning("⚠️ No text detected. Try uploading a clearer image.")
else:
    st.info("👆 Upload an image to start!")