# rag.py
"""
RAG utilities:
- normalize_text(s): clean up a single string
- normalize_files_in_data(folder): optionally normalize all .txt files in data/
- load_documents(): load and return a list of (text, filename)
- build_index(save_index_path='baba.index', save_texts_path='texts.pkl'): build & save index and texts
- ask_baba(question, history): simple retrieval + template answer for Gradio

Usage:
    # normalize files, then build
    python rag.py --normalize --build
    # just build from existing files (no normalization)
    python rag.py --build
    # use in app:
    from rag import ask_baba
"""

import os
import re
import pickle
from typing import List, Tuple

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

# === CONFIG ===
EMBED_MODEL = "all-MiniLM-L6-v2"
INDEX_PATH = "baba.index"
TEXTS_PATH = "texts.pkl"
DEFAULT_FILES = ["milindgatha.txt", "bhaktas.txt", "apologetics.txt", "poc_questions.txt", "satire_offerings.txt"]
DATA_FOLDER = "data"  # will also read *.txt inside data/
EMBED_BATCH_SIZE = 64  # batch size passed to model.encode
TOP_K = 3

# === Model load (singleton) ===
_model = None

def get_model():
    global _model
    if _model is None:
        _model = SentenceTransformer(EMBED_MODEL)
    return _model

# === Normalization utilities ===
def normalize_text(s: str) -> str:
    """
    Normalize a text chunk:
    - replace NBSP
    - convert smart quotes to ASCII quotes
    - convert en/em dashes to hyphens/spaced dashes
    - collapse runs of whitespace into a single space
    - strip leading/trailing whitespace
    - join broken lines inside a paragraph
    """
    if s is None:
        return ""
    # Replace common unicode nuisances
    s = s.replace("\u00A0", " ")  # NBSP
    # Convert common dashes to ASCII
    s = s.replace("\u2013", "-").replace("\u2014", " - ")
    # Smart quotes -> ASCII
    s = s.replace("\u201C", '"').replace("\u201D", '"').replace("\u2018", "'").replace("\u2019", "'")
    # Replace the unicode ellipsis character
    s = s.replace("\u2026", "...")
    # Remove zero-width & control characters (except newline)
    s = re.sub(r"[\u200B-\u200F\uFEFF]", "", s)
    # Normalize line breaks: keep paragraph boundaries (blank lines),
    # but join internal line breaks into spaces before collapsing whitespace
    paragraphs = re.split(r"\n\s*\n", s)
    cleaned_paragraphs = []
    for p in paragraphs:
        # join internal lines into a single line
        p_joined = " ".join(line.strip() for line in p.splitlines())
        # collapse whitespace
        p_joined = re.sub(r"\s+", " ", p_joined).strip()
        if p_joined:
            cleaned_paragraphs.append(p_joined)
    return "\n\n".join(cleaned_paragraphs)

def normalize_files_in_data(data_folder: str = DATA_FOLDER) -> List[str]:
    """
    Normalize every .txt file inside data_folder in place.
    Returns the list of file paths that were actually rewritten.
    """
    processed = []
    if not os.path.isdir(data_folder):
        return processed
    for fname in os.listdir(data_folder):
        if not fname.lower().endswith(".txt"):
            continue
        path = os.path.join(data_folder, fname)
        try:
            with open(path, "r", encoding="utf-8") as f:
                text = f.read()
        except UnicodeDecodeError:
            # fall back to latin-1
            with open(path, "r", encoding="latin-1") as f:
                text = f.read()
        norm = normalize_text(text)
        # only overwrite if something actually changed
        if norm != text:
            with open(path, "w", encoding="utf-8") as f:
                f.write(norm)
            processed.append(path)
    return processed
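
# Example (illustrative): normalize_files_in_data("data") might return
# ["data/apologetics.txt"] if that was the only file whose text changed.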

# === Document loading ===
def load_documents() -> List[Tuple[str, str]]:
    """
    Load documents from DEFAULT_FILES and any .txt files inside DATA_FOLDER.
    Returns a list of tuples: (cleaned_text_paragraph, source_filename).
    Splits on paragraph (double newline) boundaries and cleans each chunk.
    """
    docs = []
    files_to_load = list(DEFAULT_FILES)
    # add files from the data folder, but don't duplicate names
    if os.path.isdir(DATA_FOLDER):
        for fname in sorted(os.listdir(DATA_FOLDER)):
            if fname.lower().endswith(".txt") and fname not in files_to_load:
                files_to_load.append(os.path.join(DATA_FOLDER, fname))
    for filename in files_to_load:
        # allow files to live either in the repo root or in data/
        if not os.path.exists(filename):
            alt = os.path.join(DATA_FOLDER, filename)
            if os.path.exists(alt):
                filename = alt
            else:
                continue
        try:
            with open(filename, "r", encoding="utf-8") as f:
                text = f.read()
        except UnicodeDecodeError:
            with open(filename, "r", encoding="latin-1") as f:
                text = f.read()
        # normalize the whole file first
        normalized = normalize_text(text)
        # split into paragraphs (double newline)
        paragraphs = [p.strip() for p in normalized.split("\n\n") if p.strip()]
        for p in paragraphs:
            docs.append((p, os.path.basename(filename)))
    return docs
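
# Shape of the result (illustrative): each entry pairs one cleaned paragraph
# with the basename of its source file, e.g. ("<paragraph text>", "milindgatha.txt").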

# === Indexing ===
def build_index(save_index_path: str = INDEX_PATH, save_texts_path: str = TEXTS_PATH, rebuild: bool = True):
    """
    Build embeddings for all loaded documents and save the index + texts.
    Overwrites existing index/text files.
    """
    docs = load_documents()
    if not docs:
        raise RuntimeError("No documents found to index. Check files and DATA_FOLDER.")
    texts = [d[0] for d in docs]
    model = get_model()
    # encode in batches of EMBED_BATCH_SIZE (the doc set is small, so this is quick)
    embeddings = model.encode(texts, batch_size=EMBED_BATCH_SIZE, convert_to_numpy=True, show_progress_bar=True)
    # create an inner-product index; normalize embeddings to unit vectors
    # so that inner product equals cosine similarity
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    embeddings = embeddings / norms
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings.astype("float32"))
    # save index and texts
    faiss.write_index(index, save_index_path)
    with open(save_texts_path, "wb") as f:
        pickle.dump(texts, f)
    print(f"[build_index] Saved index -> {save_index_path}, texts -> {save_texts_path}")
    return index, texts
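
# Quick sanity check (a minimal sketch; assumes build_index() has just run):
#   index, texts = build_index()
#   q = get_model().encode(["Who is Baba Milind?"], convert_to_numpy=True)
#   q = q / np.linalg.norm(q, axis=1, keepdims=True)
#   D, I = index.search(q.astype("float32"), TOP_K)  # D holds cosine scores
#   print(texts[I[0][0]])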

# === Loading saved index ===
def load_index(index_path: str = INDEX_PATH, texts_path: str = TEXTS_PATH):
    if not os.path.exists(index_path) or not os.path.exists(texts_path):
        return None, None
    index = faiss.read_index(index_path)
    with open(texts_path, "rb") as f:
        texts = pickle.load(f)
    return index, texts

# === ask_baba retrieval wrapper ===
_index = None
_texts = None

def ensure_index_loaded():
    global _index, _texts
    if _index is None or _texts is None:
        idx, texts = load_index()
        if idx is None:
            # build automatically if no saved index is found
            print("[ensure_index_loaded] No saved index found; building now.")
            idx, texts = build_index()
        _index, _texts = idx, texts
    return _index, _texts

def ask_baba(question: str, history=None) -> str:
    """
    Retrieve the top-k chunks and format an answer for Gradio.
    (The `history` argument is accepted for chat-interface compatibility
    but is not used.)
    """
    idx, texts = ensure_index_loaded()
    if idx is None or not texts:
        return "Brahmarshi is silent; no index available."
    model = get_model()
    q_emb = model.encode([question], convert_to_numpy=True)
    # normalize the query the same way as the indexed embeddings
    q_emb = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-12)
    D, I = idx.search(q_emb.astype("float32"), TOP_K)
    ids = I[0]
    # FAISS returns -1 for missing neighbors, so guard the lower bound too
    top_chunks = [texts[i] for i in ids if 0 <= i < len(texts)]
    answer = "**Brahmarshi Baba Milind says:**\n\n"
    answer += f"O seeker, you ask: *'{question}'* -- know that even this question is sacred.\n\n"
    for chunk in top_chunks:
        answer += f"- {chunk}\n\n"
    answer += "Faith in Baba Milind triumphs over evidence, every single time."
    return answer
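
# Example Gradio wiring (a minimal sketch; assumes `gradio` is installed and
# the app uses ask_baba as its chat handler):
#   import gradio as gr
#   gr.ChatInterface(ask_baba).launch()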

# === CLI entrypoint ===
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="RAG helper: normalize source files and build a FAISS index.")
    parser.add_argument("--normalize", action="store_true", help="Normalize all .txt files in the data/ folder (in place).")
    parser.add_argument("--build", action="store_true", help="Build the index and save it to disk.")
    args = parser.parse_args()
    if args.normalize:
        processed = normalize_files_in_data(DATA_FOLDER)
        print(f"[normalize] Processed {len(processed)} files: {processed}")
    if args.build:
        build_index()