import json
from pathlib import Path
from typing import List, Dict, Optional

import numpy as np
from sentence_transformers import SentenceTransformer

# ---------- Global embedder (loaded once, CPU-safe) ----------
_EMBEDDER: Optional[SentenceTransformer] = None


def _get_embedder() -> SentenceTransformer:
    global _EMBEDDER
    if _EMBEDDER is None:
        # Explicit device="cpu" avoids any device_map/meta init paths.
        # Use the canonical model id to avoid redirect surprises.
        _EMBEDDER = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2",
            device="cpu",
        )
        # Optional: shorten for speed on Spaces; keep accuracy reasonable
        _EMBEDDER.max_seq_length = 256
    return _EMBEDDER


def load_index(env: Dict):
    import faiss

    index_path = Path(env["INDEX_DIR"]) / "faiss.index"
    meta_path = Path(env["INDEX_DIR"]) / "meta.json"
    if not index_path.exists():
        raise RuntimeError("Index not found. Run ingest first.")
    index = faiss.read_index(str(index_path))
    with open(meta_path, "r") as f:
        metas = json.load(f)
    return index, metas


def embed(texts: List[str]) -> np.ndarray:
    emb = _get_embedder()
    vecs = emb.encode(
        texts,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
        batch_size=32,
    )
    # FAISS expects float32
    if vecs.dtype != np.float32:
        vecs = vecs.astype(np.float32, copy=False)
    return vecs


def search(q: str, env: Dict, top_k: int = 15, filters: Optional[Dict] = None) -> List[Dict]:
    index, metas = load_index(env)
    qv = embed([q])  # shape (1, d) float32

    # Defensive: ensure index dim matches query dim
    if hasattr(index, "d") and index.d != qv.shape[1]:
        raise RuntimeError(
            f"FAISS index dim {getattr(index, 'd', '?')} "
            f"!= embedding dim {qv.shape[1]}"
        )

    scores, idxs = index.search(qv, top_k)  # scores shape (1, k), idxs shape (1, k)

    results = []
    f_geo = (filters or {}).get("geo")
    f_cats = (filters or {}).get("categories")

    for score, idx in zip(scores[0], idxs[0]):
        if idx == -1:
            continue
        m = dict(metas[idx])  # copy so we don't mutate the cached list
        if f_geo and m.get("geo") not in f_geo:
            continue
        if f_cats:
            if not set(f_cats).intersection(set(m.get("categories", []))):
                continue
        m["score"] = float(score)
        results.append(m)

    return results
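

# ---------- Usage sketch (illustrative only) ----------
# A minimal example of how search() might be invoked, assuming an index has
# already been built under env["INDEX_DIR"] by a separate ingest step (i.e.
# faiss.index and meta.json exist there). The INDEX_DIR value and the filter
# values below are hypothetical placeholders; the filter keys ("geo",
# "categories") mirror what search() reads above.
if __name__ == "__main__":
    env = {"INDEX_DIR": "index"}  # hypothetical path to faiss.index + meta.json
    hits = search(
        "example query text",
        env,
        top_k=5,
        filters={"geo": ["US"], "categories": ["news"]},  # optional; pass None for no filtering
    )
    for h in hits:
        # Only "score" is guaranteed by this module; the rest of each dict
        # comes from whatever metadata the ingest step stored in meta.json.
        print(round(h["score"], 3), h)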