grants-rag / app /search.py
michaellupo74's picture
Update app/search.py
8a31d42 verified
import json
from pathlib import Path
from typing import List, Dict, Optional
import numpy as np
from sentence_transformers import SentenceTransformer
# ---------- Global embedder (loaded once, CPU-safe) ----------
# Process-wide model instance. It stays None until _get_embedder() first runs,
# so the model weights are only loaded when an embedding is actually needed.
_EMBEDDER: Optional[SentenceTransformer] = None


def _get_embedder() -> SentenceTransformer:
    """Return the shared sentence-embedding model, building it on first use.

    The model is created once and cached in the module-level ``_EMBEDDER``.
    It is pinned to the CPU: an explicit ``device="cpu"`` sidesteps any
    device_map / meta-tensor initialisation paths. It is loaded by its
    canonical hub id to avoid redirect surprises. ``max_seq_length`` is
    lowered to 256 tokens for speed on Spaces while keeping accuracy
    reasonable.
    """
    global _EMBEDDER
    if _EMBEDDER is not None:
        # Fast path: the model has already been constructed.
        return _EMBEDDER

    _EMBEDDER = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        device="cpu",
    )
    _EMBEDDER.max_seq_length = 256
    return _EMBEDDER
def load_index(env: Dict):
    """Load the FAISS index and its per-vector metadata from disk.

    Both files live in ``env["INDEX_DIR"]``: ``faiss.index`` holds the
    vectors and ``meta.json`` holds the metadata records. Search code indexes
    the records by FAISS row id.

    Args:
        env: Settings mapping; must contain ``"INDEX_DIR"``.

    Returns:
        ``(index, metas)``: the deserialized FAISS index and the parsed
        JSON metadata.

    Raises:
        KeyError: if ``env`` has no ``"INDEX_DIR"``.
        RuntimeError: if either ``faiss.index`` or ``meta.json`` is missing
            (i.e. ingest has not been run).
    """
    index_dir = Path(env["INDEX_DIR"])
    index_path = index_dir / "faiss.index"
    meta_path = index_dir / "meta.json"

    # Check cheap preconditions before importing the heavy native library.
    if not index_path.exists():
        raise RuntimeError("Index not found. Run ingest first.")
    if not meta_path.exists():
        raise RuntimeError(
            f"Index metadata not found at {meta_path}. Run ingest first."
        )

    import faiss

    index = faiss.read_index(str(index_path))
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(meta_path, "r", encoding="utf-8") as f:
        metas = json.load(f)
    return index, metas
def embed(texts: List[str]) -> np.ndarray:
    """Encode ``texts`` into L2-normalized float32 embedding vectors.

    Args:
        texts: Strings to embed; they are encoded in batches of 32.

    Returns:
        A C-contiguous ``float32`` array of shape ``(len(texts), d)``, where
        ``d`` is the model's embedding size. This is the layout FAISS expects.
        Empty input yields a ``(0, d)`` array rather than the 1-D empty array
        ``encode`` produces, so callers can still rely on ``shape[1]``.
    """
    model = _get_embedder()

    # len() rather than truthiness so array-like inputs work as well as lists.
    if len(texts) == 0:
        dim = model.get_sentence_embedding_dimension() or 0
        return np.empty((0, dim), dtype=np.float32)

    vecs = model.encode(
        texts,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
        batch_size=32,
    )
    # FAISS expects contiguous float32; this is a no-op when already so.
    return np.ascontiguousarray(vecs, dtype=np.float32)
def _as_list(value) -> list:
    """Normalize a filter/metadata value into a list of values.

    A lone string becomes a one-element list, so membership tests compare
    whole values instead of substrings or individual characters. ``None``
    becomes ``[]``.
    """
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


def search(
    q: str,
    env: Dict,
    top_k: int = 15,
    filters: Optional[Dict] = None,
    overfetch: int = 5,
) -> List[Dict]:
    """Return up to ``top_k`` metadata records nearest to the query ``q``.

    The query is embedded with ``embed`` and looked up in the FAISS index
    loaded from ``env["INDEX_DIR"]``. Each result is a shallow copy of the
    matching metadata record plus a ``"score"`` key holding the raw FAISS
    score. Results keep the order in which FAISS returned the hits.

    Args:
        q: Free-text query.
        env: Settings mapping; must contain ``"INDEX_DIR"``.
        top_k: Maximum number of results. Values ``<= 0`` return ``[]``.
        filters: Optional mapping with either or both keys:
            ``"geo"``: allowed values of a record's ``"geo"`` field. A single
            string is treated as a one-element list.
            ``"categories"``: a record is kept if any of its
            ``"categories"`` appears here. A single string is treated as one
            category.
            Missing or empty filter values are ignored.
        overfetch: When a filter is active, FAISS is asked for up to
            ``top_k * overfetch`` candidates, capped at the index size.
            Records dropped by the filters then do not shrink the result list
            below ``top_k`` unnecessarily.

    Raises:
        RuntimeError: if the index files are missing (via ``load_index``) or
            the index dimension differs from the embedding dimension.
    """
    if top_k <= 0:
        return []

    index, metas = load_index(env)
    qv = embed([q])  # shape (1, d) float32

    # Defensive: a model swap without re-ingesting would otherwise fail
    # deep inside FAISS with a less helpful error.
    if hasattr(index, "d") and index.d != qv.shape[1]:
        raise RuntimeError(f"FAISS index dim {index.d} "
                           f"!= embedding dim {qv.shape[1]}")

    active = filters or {}
    raw_geo = active.get("geo")
    raw_cats = active.get("categories")
    geo_allowed = _as_list(raw_geo) if raw_geo else None
    cats_allowed = set(_as_list(raw_cats)) if raw_cats else None
    filtering = geo_allowed is not None or cats_allowed is not None

    k = top_k
    if filtering:
        # Fetch extra candidates so post-filtering can still fill top_k.
        # Never ask for fewer than top_k, matching the unfiltered path.
        k = top_k * max(1, overfetch)
        ntotal = getattr(index, "ntotal", 0)
        if ntotal:
            k = max(top_k, min(k, ntotal))

    scores, idxs = index.search(qv, k)  # scores/idxs shape (1, k)

    results: List[Dict] = []
    for score, idx in zip(scores[0], idxs[0]):
        if idx < 0:
            # FAISS pads with -1 when fewer than k neighbours exist.
            continue
        meta = metas[idx]
        if geo_allowed is not None and meta.get("geo") not in geo_allowed:
            continue
        if cats_allowed is not None and cats_allowed.isdisjoint(
            _as_list(meta.get("categories"))
        ):
            continue
        hit = dict(meta)  # copy so we don't mutate the loaded metadata list
        hit["score"] = float(score)
        results.append(hit)
        if len(results) >= top_k:
            break
    return results