Spaces:
Running
Running
import json | |
from pathlib import Path | |
from typing import List, Dict, Optional | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
# ---------- Process-wide embedder (built lazily, pinned to CPU) ----------
_EMBEDDER: Optional[SentenceTransformer] = None


def _get_embedder() -> SentenceTransformer:
    """Return the shared ``SentenceTransformer``, creating it on first use.

    The model is constructed at most once per process and stored in the
    module-level ``_EMBEDDER``; later calls return that same object.
    """
    global _EMBEDDER
    if _EMBEDDER is not None:
        return _EMBEDDER

    # Canonical model id (no redirect surprises) and an explicit CPU device,
    # which sidesteps device_map/meta-tensor initialisation paths.
    _EMBEDDER = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2", device="cpu"
    )
    # Cap the sequence length: faster on Spaces with reasonable accuracy.
    _EMBEDDER.max_seq_length = 256
    return _EMBEDDER
def load_index(env: Dict):
    """Load the FAISS index and its metadata from ``env["INDEX_DIR"]``.

    Args:
        env: Mapping containing ``"INDEX_DIR"``, the directory that holds
            ``faiss.index`` and ``meta.json``.

    Returns:
        An ``(index, metas)`` tuple: the object returned by
        ``faiss.read_index`` and the parsed contents of ``meta.json``.

    Raises:
        KeyError: If ``env`` has no ``"INDEX_DIR"`` entry.
        RuntimeError: If ``faiss.index`` does not exist in the directory.
        FileNotFoundError: If ``meta.json`` does not exist in the directory.
    """
    index_dir = Path(env["INDEX_DIR"])
    index_path = index_dir / "faiss.index"
    meta_path = index_dir / "meta.json"

    # Check before importing faiss so that a missing index produces the
    # actionable message without first paying for the import.
    if not index_path.exists():
        raise RuntimeError("Index not found. Run ingest first.")

    import faiss  # deferred: only needed once there is an index to read

    index = faiss.read_index(str(index_path))
    # JSON is UTF-8; do not depend on the platform's locale encoding.
    with open(meta_path, "r", encoding="utf-8") as f:
        metas = json.load(f)
    return index, metas
def embed(texts: List[str]) -> np.ndarray:
    """Encode *texts* with the shared embedder and return float32 vectors.

    Args:
        texts: Strings to embed.

    Returns:
        The encoder output as a NumPy array of dtype ``float32``, requested
        with ``normalize_embeddings=True``.
    """
    encoder = _get_embedder()
    vectors = encoder.encode(
        texts,
        batch_size=32,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    # FAISS expects float32. With copy=False the same array is handed back
    # when it already has that dtype, so no explicit dtype test is needed.
    return vectors.astype(np.float32, copy=False)
def search(
    q: str,
    env: Dict,
    top_k: int = 15,
    filters: Optional[Dict] = None,
) -> List[Dict]:
    """Run a vector search for *q* and return the matching metadata records.

    Args:
        q: Query text.
        env: Passed to ``load_index``; must contain ``"INDEX_DIR"``.
        top_k: Number of nearest neighbours requested from the index. Values
            ``<= 0`` yield an empty result.
        filters: Optional post-filters. ``"geo"`` is a container that a
            record's ``"geo"`` value must be in; ``"categories"`` is an
            iterable that must share at least one element with the record's
            ``"categories"``. Missing or falsy entries disable that filter.

    Returns:
        Copies of the metadata dicts for the surviving hits, in the order the
        index returned them, each with an added float ``"score"`` key.
        Filters are applied after the search, so fewer than ``top_k`` records
        may be returned.

    Raises:
        RuntimeError: If the index is missing (from ``load_index``) or its
            dimensionality differs from the query embedding's.
    """
    index, metas = load_index(env)
    if top_k <= 0:
        # Nothing was asked for; don't hand the index an invalid k.
        return []

    qv = embed([q])  # shape (1, d), float32

    # Defensive: an index built with a different model cannot be searched.
    if hasattr(index, "d") and index.d != qv.shape[1]:
        raise RuntimeError(f"FAISS index dim {index.d} "
                           f"!= embedding dim {qv.shape[1]}")

    scores, idxs = index.search(qv, top_k)  # both shaped (1, k)

    active = filters or {}
    f_geo = active.get("geo")
    f_cats = active.get("categories")
    # Build the category set once rather than once per hit.
    wanted_cats = set(f_cats) if f_cats else None

    results = []
    for score, idx in zip(scores[0], idxs[0]):
        if idx == -1:
            # FAISS marks unfilled result slots with -1.
            continue
        m = dict(metas[idx])  # copy so the loaded metadata entry is not mutated
        if f_geo and m.get("geo") not in f_geo:
            continue
        # `or []` also covers records whose "categories" is stored as null.
        if wanted_cats is not None and wanted_cats.isdisjoint(
            m.get("categories") or []
        ):
            continue
        m["score"] = float(score)
        results.append(m)
    return results