Normalize text, update RAG: add normalization + index build
eb4c2ed
# rag.py
"""
RAG utilities:
- normalize_text(s): clean up a single string
- normalize_files_in_data(folder): optionally normalize all .txt files in data/
- load_documents(): load and return list of (text, filename)
- build_index(save_index_path='baba.index', save_texts_path='texts.pkl'): build & save index and texts
- ask_baba(question, history): simple retrieval + template answer for Gradio
Usage:
# normalize files then build
python rag.py --normalize --build
# just build from existing files (no normalization)
python rag.py --build
# use in app:
from rag import ask_baba
"""
from sentence_transformers import SentenceTransformer
import faiss
import os
import re
import pickle
import numpy as np
from typing import List, Tuple
# === CONFIG ===
EMBED_MODEL = "all-MiniLM-L6-v2"
INDEX_PATH = "baba.index"
TEXTS_PATH = "texts.pkl"
DEFAULT_FILES = ["milindgatha.txt", "bhaktas.txt", "apologetics.txt", "poc_questions.txt", "satire_offerings.txt"]
DATA_FOLDER = "data" # will also read *.txt inside data/
EMBED_BATCH_SIZE = 64 # if needed later
TOP_K = 3
# === Model load (singleton) ===
_model = None
def get_model():
global _model
if _model is None:
_model = SentenceTransformer(EMBED_MODEL)
return _model
# === Normalization utilities ===
def normalize_text(s: str) -> str:
"""
Normalize a text chunk:
- replace NBSP
- convert smart quotes to ASCII quotes
- convert en/em dashes to hyphens/spaced dash
- collapse multiple whitespace into single space
- strip leading/trailing whitespace
- join broken lines inside a paragraph
"""
if s is None:
return ""
# Replace common unicode nuisances
s = s.replace("\u00A0", " ") # NBSP
# convert common dashes to ASCII
s = s.replace("\u2013", "-").replace("\u2014", " - ")
# smart quotes -> ascii
s = s.replace("β€œ", '"').replace("”", '"').replace("β€˜", "'").replace("’", "'")
# replace weird ellipsis char
s = s.replace("\u2026", "...")
    # Remove zero-width characters, directional marks, and the BOM
s = re.sub(r"[\u200B-\u200F\uFEFF]", "", s)
    # Normalize line breaks: split on blank lines to keep paragraph boundaries,
    # join the lines inside each paragraph into one line, then collapse whitespace.
paragraphs = re.split(r"\n\s*\n", s)
cleaned_paragraphs = []
for p in paragraphs:
# join internal lines into a single line
p_joined = " ".join(line.strip() for line in p.splitlines())
# collapse whitespace
p_joined = re.sub(r"\s+", " ", p_joined).strip()
if p_joined:
cleaned_paragraphs.append(p_joined)
return "\n\n".join(cleaned_paragraphs)
def normalize_files_in_data(data_folder: str = DATA_FOLDER) -> List[str]:
"""
Normalize every .txt file inside data_folder in-place.
Returns list of files processed.
"""
processed = []
if not os.path.isdir(data_folder):
return processed
for fname in os.listdir(data_folder):
if not fname.lower().endswith(".txt"):
continue
path = os.path.join(data_folder, fname)
try:
with open(path, "r", encoding="utf-8") as f:
text = f.read()
except UnicodeDecodeError:
# try latin-1 fallback
with open(path, "r", encoding="latin-1") as f:
text = f.read()
norm = normalize_text(text)
# only overwrite if changed
if norm != text:
with open(path, "w", encoding="utf-8") as f:
f.write(norm)
processed.append(path)
return processed
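# Illustrative usage (file name is hypothetical): normalize_files_in_data("data") rewrites only the
# .txt files whose content actually changed and returns their paths, e.g. ["data/bhaktas.txt"].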
# === Document loading ===
def load_documents() -> List[Tuple[str, str]]:
"""
Load documents from DEFAULT_FILES and any .txt files inside DATA_FOLDER.
Returns list of tuples: (cleaned_text_paragraph, source_filename).
Splits on paragraph (double newline) boundaries and cleans each chunk.
"""
docs = []
files_to_load = list(DEFAULT_FILES)
# add files from data folder, but don't duplicate names
if os.path.isdir(DATA_FOLDER):
for fname in sorted(os.listdir(DATA_FOLDER)):
if fname.lower().endswith(".txt") and fname not in files_to_load:
files_to_load.append(os.path.join(DATA_FOLDER, fname))
for filename in files_to_load:
        # Resolve the path: use it as given if it exists, otherwise look under data/
if not os.path.exists(filename):
            # fall back to data/<filename>
alt = os.path.join(DATA_FOLDER, filename)
if os.path.exists(alt):
filename = alt
else:
continue
try:
with open(filename, "r", encoding="utf-8") as f:
text = f.read()
except UnicodeDecodeError:
with open(filename, "r", encoding="latin-1") as f:
text = f.read()
# normalize whole file first
normalized = normalize_text(text)
# split into paragraphs (double newline)
paragraphs = [p.strip() for p in normalized.split("\n\n") if p.strip()]
for p in paragraphs:
docs.append((p, os.path.basename(filename)))
return docs
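# Shape of the result, with illustrative text values:
#   [("First paragraph of the gatha ...", "milindgatha.txt"),
#    ("Second paragraph ...", "milindgatha.txt"), ...]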
# === Indexing ===
def build_index(save_index_path: str = INDEX_PATH, save_texts_path: str = TEXTS_PATH, rebuild: bool = True):
"""
Build embeddings for all loaded documents and save index + texts.
Overwrites existing index/text files.
"""
docs = load_documents()
if not docs:
raise RuntimeError("No documents found to index. Check files and DATA_FOLDER.")
texts = [d[0] for d in docs]
model = get_model()
    # Encode everything in one call (small doc set); for larger corpora pass batch_size=EMBED_BATCH_SIZE.
embeddings = model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
    # Build an inner-product index; embeddings are normalized to unit vectors below
    # so that inner product equals cosine similarity.
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1.0
embeddings = embeddings / norms
dim = embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(embeddings.astype('float32'))
# save index and texts
faiss.write_index(index, save_index_path)
with open(save_texts_path, "wb") as f:
pickle.dump(texts, f)
print(f"[build_index] Saved index -> {save_index_path}, texts -> {save_texts_path}")
return index, texts
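# Sanity note: because the stored vectors are unit-normalized, the inner product computed by
# IndexFlatIP is the cosine similarity, so searching a stored chunk's own embedding against
# the index should return a top score of ~1.0.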
# === Loading saved index ===
def load_index(index_path: str = INDEX_PATH, texts_path: str = TEXTS_PATH):
if not os.path.exists(index_path) or not os.path.exists(texts_path):
return None, None
index = faiss.read_index(index_path)
with open(texts_path, "rb") as f:
texts = pickle.load(f)
return index, texts
# === ask_baba retrieval wrapper ===
_index = None
_texts = None
def ensure_index_loaded():
global _index, _texts
if _index is None or _texts is None:
idx, texts = load_index()
if idx is None:
# try to build automatically if no index found
print("[ensure_index_loaded] No saved index found β€” building now.")
idx, texts = build_index()
_index, _texts = idx, texts
return _index, _texts
def ask_baba(question: str, history=None) -> str:
"""
    Retrieve the top-k chunks for `question` and format a templated answer
    for the Gradio chat interface (`history` is accepted but not used).
"""
idx, texts = ensure_index_loaded()
if idx is None or not texts:
return "Brahmarshi is silent β€” no index available."
model = get_model()
q_emb = model.encode([question], convert_to_numpy=True)
# normalize
q_emb = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-12)
D, I = idx.search(q_emb.astype('float32'), TOP_K)
ids = I[0]
    top_chunks = [texts[i] for i in ids if 0 <= i < len(texts)]  # guard against FAISS's -1 padding
answer = "πŸ™ **Brahmarshi Baba Milind says:**\n\n"
answer += f"O seeker, you ask: *'{question}'* β€” know that even this question is sacred.\n\n"
for chunk in top_chunks:
answer += f"- {chunk}\n\n"
answer += "Faith in Baba Milind triumphs over evidence, every single time."
return answer
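# Illustrative Gradio wiring (assumes the gradio package is installed; not part of this module):
#   import gradio as gr
#   gr.ChatInterface(fn=ask_baba).launch()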
# === CLI entrypoint ===
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="RAG helper: normalize source files and build FAISS index.")
parser.add_argument("--normalize", action="store_true", help="Normalize all .txt files in data/ folder (in-place).")
parser.add_argument("--build", action="store_true", help="Build index (and save to disk).")
args = parser.parse_args()
if args.normalize:
processed = normalize_files_in_data(DATA_FOLDER)
print(f"[normalize] Processed {len(processed)} files: {processed}")
if args.build:
build_index()