# rag.py
"""
RAG utilities:
- normalize_text(s): clean up a single string
- normalize_files_in_data(folder): optionally normalize all .txt files in data/
- load_documents(): load and return list of (text, filename)
- build_index(save_index_path='baba.index', save_texts_path='texts.pkl'): build & save index and texts
- ask_baba(question, history): simple retrieval + template answer for Gradio

Usage:
    # normalize files then build
    python rag.py --normalize --build

    # just build from existing files (no normalization)
    python rag.py --build

    # use in app:
    from rag import ask_baba
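    # wire into a Gradio chat UI (illustrative sketch, assuming gradio is installed):
    import gradio as gr
    gr.ChatInterface(fn=ask_baba).launch()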
"""
from sentence_transformers import SentenceTransformer
import faiss
import os
import re
import pickle
import numpy as np
from typing import List, Tuple

# === CONFIG ===
EMBED_MODEL = "all-MiniLM-L6-v2"
INDEX_PATH = "baba.index"
TEXTS_PATH = "texts.pkl"
DEFAULT_FILES = ["milindgatha.txt", "bhaktas.txt", "apologetics.txt", "poc_questions.txt", "satire_offerings.txt"]
DATA_FOLDER = "data"  # will also read *.txt inside data/
EMBED_BATCH_SIZE = 64  # if needed later
TOP_K = 3

# === Model load (singleton) ===
_model = None


def get_model():
    global _model
    if _model is None:
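        # first call downloads the model from the Hugging Face hub and caches it locally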
        _model = SentenceTransformer(EMBED_MODEL)
    return _model


# === Normalization utilities ===
def normalize_text(s: str) -> str:
    """
    Normalize a text chunk:
      - replace NBSP
      - convert smart quotes to ASCII quotes
      - convert en/em dashes to hyphens/spaced dash
      - collapse multiple whitespace into single space
      - strip leading/trailing whitespace
      - join broken lines inside a paragraph
    """
    if s is None:
        return ""
    # Replace common unicode nuisances
    s = s.replace("\u00A0", " ")   # NBSP
    # convert common dashes to ASCII
    s = s.replace("\u2013", "-").replace("\u2014", " - ")
    # smart quotes -> ascii
    s = s.replace("\u201c", '"').replace("\u201d", '"').replace("\u2018", "'").replace("\u2019", "'")
    # replace weird ellipsis char
    s = s.replace("\u2026", "...")
    # Remove zero-width & control characters (except newline)
    s = re.sub(r"[\u200B-\u200F\uFEFF]", "", s)
    # Normalize line breaks: split on blank lines to keep paragraph boundaries,
    # then join the lines inside each paragraph and collapse runs of whitespace.
    paragraphs = re.split(r"\n\s*\n", s)
    cleaned_paragraphs = []
    for p in paragraphs:
        # join internal lines into a single line
        p_joined = " ".join(line.strip() for line in p.splitlines())
        # collapse whitespace
        p_joined = re.sub(r"\s+", " ", p_joined).strip()
        if p_joined:
            cleaned_paragraphs.append(p_joined)
    return "\n\n".join(cleaned_paragraphs)


def normalize_files_in_data(data_folder: str = DATA_FOLDER) -> List[str]:
    """
    Normalize every .txt file inside data_folder in-place.
    Returns list of files processed.
    """
    processed = []
    if not os.path.isdir(data_folder):
        return processed
    for fname in os.listdir(data_folder):
        if not fname.lower().endswith(".txt"):
            continue
        path = os.path.join(data_folder, fname)
        try:
            with open(path, "r", encoding="utf-8") as f:
                text = f.read()
        except UnicodeDecodeError:
            # try latin-1 fallback
            with open(path, "r", encoding="latin-1") as f:
                text = f.read()
        norm = normalize_text(text)
        # only overwrite if changed
        if norm != text:
            with open(path, "w", encoding="utf-8") as f:
                f.write(norm)
        processed.append(path)
    return processed


# === Document loading ===
def load_documents() -> List[Tuple[str, str]]:
    """
    Load documents from DEFAULT_FILES and any .txt files inside DATA_FOLDER.
    Returns list of tuples: (cleaned_text_paragraph, source_filename).
    Splits on paragraph (double newline) boundaries and cleans each chunk.
    """
    docs = []
    files_to_load = list(DEFAULT_FILES)

    # add files from data folder, but don't duplicate names
    if os.path.isdir(DATA_FOLDER):
        for fname in sorted(os.listdir(DATA_FOLDER)):
            if fname.lower().endswith(".txt") and fname not in files_to_load:
                files_to_load.append(os.path.join(DATA_FOLDER, fname))

    for filename in files_to_load:
        # resolve the path: try it as given first, then fall back to data/<name>; skip if neither exists
        if not os.path.exists(filename):
            alt = os.path.join(DATA_FOLDER, filename)
            if os.path.exists(alt):
                filename = alt
            else:
                continue
        try:
            with open(filename, "r", encoding="utf-8") as f:
                text = f.read()
        except UnicodeDecodeError:
            with open(filename, "r", encoding="latin-1") as f:
                text = f.read()

        # normalize whole file first
        normalized = normalize_text(text)
        # split into paragraphs (double newline)
        paragraphs = [p.strip() for p in normalized.split("\n\n") if p.strip()]
        for p in paragraphs:
            docs.append((p, os.path.basename(filename)))
    return docs


# === Indexing ===
def build_index(save_index_path: str = INDEX_PATH, save_texts_path: str = TEXTS_PATH, rebuild: bool = True):
    """
    Build embeddings for all loaded documents and save index + texts.
    Overwrites existing index/text files.
    """
    docs = load_documents()
    if not docs:
        raise RuntimeError("No documents found to index. Check files and DATA_FOLDER.")
    texts = [d[0] for d in docs]

    model = get_model()
    # encode in one batch (small doc set). If large, encode in batches.
    embeddings = model.encode(texts, convert_to_numpy=True, show_progress_bar=True)

    # create an inner-product index; normalize embeddings to unit vectors so that
    # inner-product search is equivalent to cosine similarity
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    embeddings = embeddings / norms

    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings.astype('float32'))
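    # With unit-norm vectors the inner product equals cosine similarity, so the
    # scores returned by index.search() lie in [-1, 1] (higher = more similar).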

    # save index and texts
    faiss.write_index(index, save_index_path)
    with open(save_texts_path, "wb") as f:
        pickle.dump(texts, f)

    print(f"[build_index] Saved index -> {save_index_path}, texts -> {save_texts_path}")
    return index, texts


# === Loading saved index ===
def load_index(index_path: str = INDEX_PATH, texts_path: str = TEXTS_PATH):
    if not os.path.exists(index_path) or not os.path.exists(texts_path):
        return None, None
    index = faiss.read_index(index_path)
    with open(texts_path, "rb") as f:
        texts = pickle.load(f)
    return index, texts
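
# Quick retrieval sanity check from a Python shell (illustrative; the sample question is hypothetical):
#   index, texts = load_index()
#   q = get_model().encode(["who is Baba Milind?"], convert_to_numpy=True)
#   q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-12)
#   scores, ids = index.search(q.astype("float32"), TOP_K)
#   print([texts[i] for i in ids[0] if 0 <= i < len(texts)])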


# === ask_baba retrieval wrapper ===
_index = None
_texts = None


def ensure_index_loaded():
    global _index, _texts
    if _index is None or _texts is None:
        idx, texts = load_index()
        if idx is None:
            # try to build automatically if no index found
            print("[ensure_index_loaded] No saved index found β€” building now.")
            idx, texts = build_index()
        _index, _texts = idx, texts
    return _index, _texts


def ask_baba(question: str, history=None) -> str:
    """
    Retrieve the top-k chunks for the question and format a templated answer for the Gradio chat.
    """
    idx, texts = ensure_index_loaded()
    if idx is None or not texts:
        return "Brahmarshi is silent β€” no index available."

    model = get_model()
    q_emb = model.encode([question], convert_to_numpy=True)
    # normalize
    q_emb = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-12)
    D, I = idx.search(q_emb.astype('float32'), TOP_K)
    ids = I[0]
    top_chunks = [texts[i] for i in ids if 0 <= i < len(texts)]  # faiss pads with -1 when fewer than k hits

    answer = "πŸ™ **Brahmarshi Baba Milind says:**\n\n"
    answer += f"O seeker, you ask: *'{question}'* β€” know that even this question is sacred.\n\n"

    for chunk in top_chunks:
        answer += f"- {chunk}\n\n"

    answer += "Faith in Baba Milind triumphs over evidence, every single time."
    return answer


# === CLI entrypoint ===
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="RAG helper: normalize source files and build FAISS index.")
    parser.add_argument("--normalize", action="store_true", help="Normalize all .txt files in data/ folder (in-place).")
    parser.add_argument("--build", action="store_true", help="Build index (and save to disk).")
    args = parser.parse_args()

    if args.normalize:
        processed = normalize_files_in_data(DATA_FOLDER)
        print(f"[normalize] Processed {len(processed)} files: {processed}")

    if args.build:
        build_index()