File size: 12,669 Bytes
39006e9
a8d6613
39006e9
 
 
 
812c5ec
39006e9
 
 
 
 
 
 
812c5ec
39006e9
 
 
ab8858c
 
eb99fed
39006e9
 
 
 
 
 
 
 
 
bfad9e2
ab8858c
39006e9
6b94cb2
eb99fed
6b94cb2
 
 
eb99fed
 
6b94cb2
 
 
 
eb99fed
6b94cb2
eb99fed
39006e9
 
 
 
ab8858c
 
 
eb99fed
ab8858c
 
 
 
 
 
 
 
 
eb99fed
ab8858c
 
 
 
 
 
 
 
eb99fed
ab8858c
39006e9
ab8858c
39006e9
6b94cb2
 
eb99fed
 
 
 
6b94cb2
 
812c5ec
6b94cb2
 
 
 
 
812c5ec
 
6b94cb2
812c5ec
 
39006e9
 
812c5ec
 
 
 
 
 
 
39006e9
 
 
ab8858c
eb99fed
39006e9
eb99fed
39006e9
 
eb99fed
39006e9
 
 
eb99fed
39006e9
 
 
 
 
 
 
 
 
 
 
eb99fed
39006e9
 
 
eb99fed
 
ab8858c
eb99fed
39006e9
 
eb99fed
39006e9
eb99fed
ab8858c
 
 
 
 
eb99fed
812c5ec
eb99fed
 
 
ab8858c
812c5ec
 
6b94cb2
 
 
 
 
eb99fed
6b94cb2
39006e9
 
 
 
 
812c5ec
39006e9
 
 
 
812c5ec
 
39006e9
ab8858c
39006e9
6b94cb2
39006e9
 
6b94cb2
eb99fed
6b94cb2
ab8858c
 
39006e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb99fed
39006e9
 
 
 
 
 
 
 
 
 
ab8858c
39006e9
 
 
 
 
 
 
 
 
 
ab8858c
 
39006e9
ab8858c
39006e9
 
ab8858c
39006e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
812c5ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
# app.py
import streamlit as st
import cv2
import numpy as np
import easyocr
from PIL import Image
import pillow_avif # enables AVIF support for Pillow
import requests
import os
import logging
import re
import torch
import time
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from deep_translator import GoogleTranslator # Replaces googletrans
# -------------------- ENV + LOGGING --------------------
MISTRAL_API_KEY = st.secrets.get("MISTRAL_API_KEY")
MISTRAL_AGENT_ID = st.secrets.get("MISTRAL_AGENT_ID")
HF_TOKEN = st.secrets.get("HF_TOKEN")
# All three secrets are required; halt the script with a visible error otherwise.
if not MISTRAL_API_KEY or not MISTRAL_AGENT_ID or not HF_TOKEN:
    st.error("❌ Missing required keys in Streamlit secrets. Please check HF_TOKEN, MISTRAL_API_KEY, and MISTRAL_AGENT_ID.")
    st.stop()
MISTRAL_URL = "https://api.mistral.ai/v1/agents/completions"
# Make the Mistral host bypass any configured proxy. Append to NO_PROXY instead of
# overwriting it, so hosts the user already excluded keep bypassing the proxy too.
_no_proxy_hosts = [h.strip() for h in os.environ.get("NO_PROXY", "").split(",") if h.strip()]
if "api.mistral.ai" not in _no_proxy_hosts:
    _no_proxy_hosts.append("api.mistral.ai")
os.environ["NO_PROXY"] = ",".join(_no_proxy_hosts)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# -------------------- STREAMLIT CONFIG --------------------
# Subtitle shown under the page title.
_APP_TAGLINE = (
    "Upload a Sanskrit manuscript → OCR → Mistral AI cleans it → "
    "Translates into Indic languages + English using AI4Bharat IndicTrans2 models."
)
# set_page_config must stay the first Streamlit element call on the success path.
st.set_page_config(
    page_title="OCR + Sanskrit Cleaner & Translator AI",
    layout="wide",
)
st.title("📖 OCR for Devanagari - Sanskrit Manuscripts + AI Cleaner + Multi-Language Translation")
st.write(_APP_TAGLINE)
# -------------------- CONSTANTS --------------------
# Language tags accepted by manual_preprocess_batch (``<iso639-3>_<Script>`` form).
VALID_TAGS = [
    "asm_Beng",
    "ben_Beng",
    "guj_Gujr",
    "hin_Deva",
    "kan_Knda",
    "mal_Mlym",
    "mar_Deva",
    "nep_Deva",
    "ori_Orya",
    "pan_Guru",
    "san_Deva",
    "tam_Taml",
    "tel_Telu",
    "eng_Latn",
]
# Display names of the Indic translation targets; insertion order is output order.
LANG_NAMES = {
    "hin_Deva": "Hindi",
    "kan_Knda": "Kannada",
    "tam_Taml": "Tamil",
    "tel_Telu": "Telugu",
}
# Target tags, derived from LANG_NAMES so the two tables cannot drift apart.
TARGET_LANGS = list(LANG_NAMES)
# -------------------- LOAD MODELS --------------------
@st.cache_resource
def load_translation_models():
    """Load both IndicTrans2 checkpoints plus a Google Translate fallback (cached).

    Uses CUDA with float16 weights when a GPU is available, otherwise CPU with
    float32. Streamlit's ``cache_resource`` keeps the loaded objects across reruns.

    Returns:
        ``(tokenizer_indic, model_indic, tokenizer_en, model_en, translator, device)``
        where ``device`` is ``"cuda"`` or ``"cpu"``.

    Raises:
        Re-raises any loading error after reporting it in the UI.
    """
    has_gpu = torch.cuda.is_available()
    device = "cuda" if has_gpu else "cpu"
    weight_dtype = torch.float16 if has_gpu else torch.float32

    def _fetch(repo_id):
        # Tokenizer and model both need the HF token and the repo's custom code.
        tok = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN, trust_remote_code=True)
        mdl = AutoModelForSeq2SeqLM.from_pretrained(
            repo_id,
            token=HF_TOKEN,
            trust_remote_code=True,
            torch_dtype=weight_dtype,
            low_cpu_mem_usage=True,
        ).to(device)
        return tok, mdl

    try:
        st.info("💪 Loading AI4Bharat IndicTrans2 models (requires Hugging Face token)...")
        tokenizer_indic, model_indic = _fetch("ai4bharat/indictrans2-indic-indic-1B")  # Indic→Indic
        tokenizer_en, model_en = _fetch("ai4bharat/indictrans2-indic-en-1B")  # Indic→English
        st.success(f"✅ Models loaded successfully on {device.upper()}.")
        translator = GoogleTranslator(source="auto", target="en")
    except Exception as exc:
        st.error(f"❌ Model loading failed: {exc}")
        raise
    return tokenizer_indic, model_indic, tokenizer_en, model_en, translator, device
# -------------------- HELPERS --------------------
def sanitize_text_for_tags(text: str) -> str:
    """Clean Sanskrit text before the IndicTrans2 language tags are prepended.

    Removes every ``<``/``>`` character and the run of trailing danda (``।``) and
    double-danda (``॥``) marks, then trims surrounding whitespace.

    Args:
        text: Text to clean.

    Returns:
        The cleaned, whitespace-trimmed text.
    """
    text = re.sub(r"[<>]", "", text)
    # Whitespace is part of the trailing run so that a closing mark followed by
    # spaces or newlines (e.g. "... ॥ ") is still removed.
    text = re.sub(r"[।॥\s]+$", "", text)
    return text.strip()
def manual_preprocess_batch(input_sentences, src_lang: str, tgt_lang: str):
    """Format text for IndicTrans2 with space-separated tags (per official docs).

    Each sentence is cleaned with ``sanitize_text_for_tags`` and prefixed as
    ``"<src_lang> <tgt_lang> <text>"``, e.g. ``"san_Deva eng_Latn ..."``.

    Args:
        input_sentences: Iterable of sentences to format.
        src_lang: Source tag; must be in ``VALID_TAGS``.
        tgt_lang: Target tag; must be in ``VALID_TAGS``.

    Returns:
        List of formatted strings, same order as ``input_sentences``.

    Raises:
        AssertionError: if either tag is not in ``VALID_TAGS``.
    """
    # Explicit raises instead of ``assert`` statements, which are stripped under
    # ``python -O``. AssertionError is kept because translate_sanskrit catches that type.
    if src_lang not in VALID_TAGS:
        raise AssertionError(f"Invalid source language tag: {src_lang}")
    if tgt_lang not in VALID_TAGS:
        raise AssertionError(f"Invalid target language tag: {tgt_lang}")
    # Format: "san_Deva eng_Latn cleaned_text" (no < > or </s>)
    return [
        f"{src_lang} {tgt_lang} {sanitize_text_for_tags(sent).strip()}"
        for sent in input_sentences
    ]
def manual_postprocess_batch(generated_tokens, tgt_lang: str = None):
    """Strip a leading target-language tag from each decoded output string.

    Args:
        generated_tokens: Iterable of already-decoded output strings (despite the
            name, these are text, not token IDs).
        tgt_lang: Tag to remove, e.g. ``"eng_Latn"``. When falsy, a leading token
            shaped like a language tag (``xxx_Xxxx``, as in ``VALID_TAGS``) is removed
            instead. Real words are never dropped.

    Returns:
        List of cleaned strings, same order as the input.
    """
    # The tag must be followed by whitespace or end-of-string, so output that is
    # *only* the tag becomes "" (letting callers detect empty output). A longer word
    # that merely starts with the tag is left alone.
    if tgt_lang:
        tag_re = re.compile(rf"^{re.escape(tgt_lang)}(?:\s+|$)")
    else:
        tag_re = re.compile(r"^[a-z]{3}_[A-Z][a-z]{3}(?:\s+|$)")
    return [tag_re.sub("", text.strip()) for text in generated_tokens]
# Anything outside the Devanagari block (U+0900–U+097F), whitespace, or danda marks.
# (। U+0964 and ॥ U+0965 already lie inside that block; listing them is harmless.)
_NON_DEVANAGARI_RE = re.compile(r"[^\u0900-\u097F\s।॥]")


def preprocess_ocr_text(text: str) -> str:
    """Drop every character that is not Devanagari, whitespace, or a danda.

    Characters are deleted, not replaced, so neighbouring whitespace is kept as-is.
    """
    return _NON_DEVANAGARI_RE.sub("", text)
def call_mistral_cleaner(noisy_text: str, max_retries=3, timeout=60) -> str:
    """Clean OCR Sanskrit text via the configured Mistral Agent.

    Posts one agents/completions request per attempt. Only HTTP 429 responses are
    retried, waiting the server's ``Retry-After`` seconds (default 60) between
    attempts. Any other failure ends the call immediately.

    Args:
        noisy_text: OCR output to clean.
        max_retries: Maximum number of request attempts.
        timeout: Per-request timeout in seconds, so a stalled connection cannot
            hang the Streamlit script indefinitely.

    Returns:
        The cleaned text, or a string starting with ``"Error"`` on failure. The UI
        checks that prefix.
    """
    headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
    payload = {
        "agent_id": MISTRAL_AGENT_ID,
        "messages": [
            {"role": "user", "content": f"Clean this noisy OCR Sanskrit text: {noisy_text}\n\nOutput only the cleaned Devanagari text."}
        ]
    }
    for attempt in range(max_retries):
        try:
            # Empty proxy mapping: talk to the Mistral API directly.
            response = requests.post(
                MISTRAL_URL,
                headers=headers,
                json=payload,
                proxies={"http": "", "https": ""},
                timeout=timeout,
            )
            if response.status_code == 429:
                if attempt + 1 >= max_retries:
                    # No attempts left: don't announce a retry and sleep just to give up.
                    break
                try:
                    retry_after = int(response.headers.get("Retry-After", 60))
                except ValueError:
                    # Retry-After may be an HTTP-date instead of seconds; use the default.
                    retry_after = 60
                retry_after = max(retry_after, 0)
                st.warning(f"⏳ Rate limit hit. Retrying in {retry_after}s... (Attempt {attempt + 1}/{max_retries})")
                time.sleep(retry_after)
                continue
            response.raise_for_status()
            result = response.json()
            # Treat a missing *or empty* choices list as "no output" instead of IndexError.
            choices = result.get("choices") or [{}]
            cleaned_text = choices[0].get("message", {}).get("content", "")
            return cleaned_text.strip() if cleaned_text else "Error: No output from Agent."
        except Exception as e:
            # UI boundary: report the failure as an "Error..." string the caller displays.
            logger.error("Error calling Mistral Agent: %s", e)
            return f"Error: {str(e)}"
    return "Error: Max retries exceeded."
# -------------------- TRANSLATION --------------------
def _generate_translation(tokenizer, model, sentences, src_lang, tgt_lang, device):
    """Run one IndicTrans2 beam-search pass and return the tag-stripped first output."""
    batch = manual_preprocess_batch(sentences, src_lang, tgt_lang)
    # NOTE(review): truncation=True silently cuts inputs longer than the tokenizer's
    # max length, so long passages may lose their tail. Consider sentence splitting.
    inputs = tokenizer(batch, truncation=True, padding="longest", return_tensors="pt").to(device)
    with torch.no_grad():
        generated = model.generate(
            **inputs,
            max_length=2048,  # NOTE(review): confirm this fits the model's max target positions
            num_beams=5,
            num_return_sequences=1,
            use_cache=False,
        )
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)[0].strip()
    return manual_postprocess_batch([decoded], tgt_lang)[0]  # Remove leading tag


def translate_sanskrit(cleaned_sanskrit, tokenizer_indic, model_indic, tokenizer_en, model_en, translator, DEVICE):
    """Translate Sanskrit → Indic + English using IndicTrans2 + fallback.

    English comes from the Indic→English model. If that output is empty, the Google
    Translate ``translator`` is used instead. Each language in ``TARGET_LANGS`` comes
    from the Indic→Indic model.

    Returns:
        Dict mapping each target tag to ``{"indic": str, "english": str,
        "lang_name": str}``. The same English text is repeated in every entry.

    Raises:
        Re-raises any tag-validation or model error after reporting it in the UI.
    """
    try:
        src_lang = "san_Deva"
        input_sentences = [sanitize_text_for_tags(cleaned_sanskrit)]

        english_trans = _generate_translation(
            tokenizer_en, model_en, input_sentences, src_lang, "eng_Latn", DEVICE
        )
        if not english_trans:
            try:
                # Guard against the fallback returning None.
                english_trans = translator.translate(cleaned_sanskrit) or ""
            except Exception as fallback_err:
                # Best-effort fallback: keep going without English, but leave a trace.
                logger.warning("Google Translate fallback failed: %s", fallback_err)
                english_trans = ""

        translations_dict = {}
        for tgt_lang in TARGET_LANGS:
            trans_indic = _generate_translation(
                tokenizer_indic, model_indic, input_sentences, src_lang, tgt_lang, DEVICE
            )
            translations_dict[tgt_lang] = {
                "indic": trans_indic,
                "english": english_trans,
                "lang_name": LANG_NAMES[tgt_lang]
            }
        return translations_dict
    except AssertionError as e:
        st.error(f"❌ Language tag error: {e}. Check preprocessing & tags.")
        raise
    except Exception as e:
        st.error(f"❌ Translation failed: {e}")
        raise
# -------------------- MAIN APP --------------------
@st.cache_resource
def _load_ocr_reader():
    """Create the EasyOCR reader for the 'hi', 'mr' and 'ne' language codes (CPU only).

    Cached with ``st.cache_resource``. Streamlit re-executes this script on every
    widget interaction, and without caching the reader would be rebuilt on each
    button click.
    """
    return easyocr.Reader(['hi', 'mr', 'ne'], gpu=False)


uploaded_file = st.file_uploader("Upload a Sanskrit manuscript image", type=["png", "jpg", "jpeg", "avif"])

# Per-session values that must survive reruns (every button click re-runs the script).
_SESSION_DEFAULTS = {
    "cleaned_sanskrit": "",  # output of the Mistral cleaning step
    "translations": None,    # dict from translate_sanskrit, or None
    "ocr_text": None,        # OCR result for the current upload, or None if not run yet
    "upload_key": None,      # identifies which upload the values above belong to
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

if uploaded_file:
    # Reset derived state when a different file is uploaded. Otherwise the previous
    # image's cleaned text and translations are shown next to the new image.
    # (Heuristic identity: file name + size.)
    upload_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.upload_key != upload_key:
        st.session_state.upload_key = upload_key
        st.session_state.cleaned_sanskrit = ""
        st.session_state.translations = None
        st.session_state.ocr_text = None

    try:
        pil_img = Image.open(uploaded_file)
        image = np.array(pil_img.convert("L"))
    except OSError as e:  # PIL's UnidentifiedImageError subclasses OSError
        st.error(f"❌ Could not read the uploaded image: {e}")
        st.stop()

    # Otsu threshold on the inverted grayscale image -> mask for the preview below.
    inverted = cv2.bitwise_not(image)
    _, mask = cv2.threshold(inverted, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    white_bg = np.ones_like(image) * 255
    final_text_only = cv2.bitwise_and(white_bg, white_bg, mask=mask)
    col1, col2 = st.columns(2)
    with col1:
        st.image(pil_img, caption="📷 Original Image", use_container_width=True)
    with col2:
        st.image(Image.fromarray(final_text_only), caption="🧾 Processed Text-Only Image", use_container_width=True)

    st.subheader("🔍 Extracted OCR Text")
    # Run OCR once per upload. Later reruns reuse the text stored in session state.
    if st.session_state.ocr_text is None:
        with st.spinner("Initializing EasyOCR..."):
            try:
                reader = _load_ocr_reader()
            except Exception as e:
                st.error(f"❌ EasyOCR initialization failed: {e}")
                st.stop()
        # NOTE: OCR reads the grayscale image, not the thresholded preview above.
        results = reader.readtext(image, detail=1, paragraph=True)
        st.session_state.ocr_text = " ".join(res[1] for res in results)
    extracted_text = st.session_state.ocr_text

    if extracted_text.strip():
        st.success("✅ OCR Extraction Successful!")
        st.text_area("Extracted Text", extracted_text, height=200)
        noisy_text = preprocess_ocr_text(extracted_text)
        if st.button("✨ Clean OCR Text with Mistral AI Agent"):
            with st.spinner("Cleaning Sanskrit text using Mistral Agent..."):
                cleaned_sanskrit = call_mistral_cleaner(noisy_text)
                # call_mistral_cleaner signals failure with an "Error..." string.
                if cleaned_sanskrit.startswith("Error"):
                    st.error(cleaned_sanskrit)
                else:
                    st.session_state.cleaned_sanskrit = cleaned_sanskrit
                    st.session_state.translations = None  # stale for the new text
        if st.session_state.cleaned_sanskrit:
            st.subheader("📜 Cleaned Sanskrit Text")
            st.text_area("Cleaned Text", st.session_state.cleaned_sanskrit, height=200)
            if st.button("🌐 Translate to Indic Languages + English"):
                st.warning("⏳ Translation on CPU may take 2–5 minutes (models load once).")
                with st.spinner("Loading AI4Bharat models and generating translations..."):
                    try:
                        tokenizer_indic, model_indic, tokenizer_en, model_en, translator, DEVICE = load_translation_models()
                        translations = translate_sanskrit(
                            st.session_state.cleaned_sanskrit,
                            tokenizer_indic, model_indic, tokenizer_en, model_en, translator, DEVICE
                        )
                        st.session_state.translations = translations
                    except Exception as e:
                        st.exception(e)
        if st.session_state.translations:
            st.subheader("🌍 Translations")
            for tgt_lang, trans_dict in st.session_state.translations.items():
                st.write(f"--- **{trans_dict['lang_name']}** ---")
                st.write(f"**Sanskrit:** {st.session_state.cleaned_sanskrit}")
                st.write(f"**{trans_dict['lang_name']}:** {trans_dict['indic']}")
                st.write(f"**English:** {trans_dict['english']}")
                st.write("---")
    else:
        st.warning("⚠️ No text detected. Try uploading a clearer image.")
else:
    st.info("👆 Upload an image to start!")