Spaces:
Running
Running
File size: 2,555 Bytes
598f5cb 8a31d42 598f5cb 8a31d42 598f5cb 8a31d42 598f5cb 8a31d42 598f5cb 8a31d42 598f5cb 8a31d42 598f5cb 8a31d42 598f5cb 8a31d42 598f5cb 8a31d42 598f5cb 8a31d42 598f5cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import json
from pathlib import Path
from typing import List, Dict, Optional
import numpy as np
from sentence_transformers import SentenceTransformer
# ---------- Global embedder (loaded once, CPU-safe) ----------
# Process-wide cache for the SentenceTransformer model. It starts as None and
# is populated lazily by _get_embedder() on first use, so importing this
# module stays cheap and the model is only downloaded/loaded when needed.
_EMBEDDER: Optional[SentenceTransformer] = None
def _get_embedder() -> SentenceTransformer:
    """Return the process-wide sentence embedder, creating it on first call.

    The model is pinned to CPU (``device="cpu"``) to avoid any
    device_map/meta-tensor initialization paths, and loaded via its canonical
    hub id to avoid redirect surprises. ``max_seq_length`` is capped at 256
    as a speed/accuracy trade-off for small hosts such as Spaces.
    """
    global _EMBEDDER
    if _EMBEDDER is not None:
        return _EMBEDDER
    model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        device="cpu"
    )
    model.max_seq_length = 256
    _EMBEDDER = model
    return _EMBEDDER
def load_index(env: Dict):
    """Load the FAISS index and its row-aligned metadata list from disk.

    Args:
        env: Configuration mapping; must contain ``"INDEX_DIR"``, the
            directory holding ``faiss.index`` and ``meta.json`` (both written
            by the ingest step).

    Returns:
        ``(index, metas)`` — the FAISS index and the parallel list of
        metadata dicts (``metas[i]`` describes the vector at position ``i``).

    Raises:
        RuntimeError: if either artifact is missing — run ingest first.
    """
    index_dir = Path(env["INDEX_DIR"])
    index_path = index_dir / "faiss.index"
    meta_path = index_dir / "meta.json"
    # Check both artifacts up front so a partial/failed ingest produces a
    # clear message instead of a raw FileNotFoundError halfway through.
    if not index_path.exists():
        raise RuntimeError("Index not found. Run ingest first.")
    if not meta_path.exists():
        raise RuntimeError("Index metadata not found. Run ingest first.")
    # Deferred import: faiss is heavy and only needed once we know the
    # artifacts exist (also keeps the error path importable without faiss).
    import faiss
    index = faiss.read_index(str(index_path))
    # meta.json is produced by ingest; read it as UTF-8 explicitly rather
    # than relying on the platform default encoding.
    with open(meta_path, "r", encoding="utf-8") as f:
        metas = json.load(f)
    return index, metas
def embed(texts: List[str]) -> np.ndarray:
    """Encode *texts* into L2-normalized float32 vectors.

    Uses the shared lazy-loaded embedder; normalization is done by the
    encoder itself (``normalize_embeddings=True``), and the result is cast
    to float32 because that is what FAISS expects.
    """
    model = _get_embedder()
    vectors = model.encode(
        texts,
        batch_size=32,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    # No-op when already float32; otherwise converts without an extra copy
    # of the original buffer being kept around.
    return np.asarray(vectors, dtype=np.float32)
def search(q: str, env: Dict, top_k: int = 15,
           filters: Optional[Dict] = None) -> List[Dict]:
    """Embed query *q* and return metadata for its nearest index neighbors.

    Args:
        q: Query text.
        env: Config mapping with ``"INDEX_DIR"`` (see ``load_index``).
        top_k: Number of neighbors to retrieve from FAISS.
        filters: Optional mapping with keys ``"geo"`` (container of allowed
            geo values) and/or ``"categories"`` (container of category
            strings; a hit passes if any category overlaps). NOTE: filtering
            happens AFTER retrieval, so fewer than ``top_k`` results may be
            returned when filters are set.

    Returns:
        List of metadata dicts (copies of the cached entries), each
        augmented with a float ``"score"`` similarity value.

    Raises:
        RuntimeError: if the index is missing or its dimensionality does not
            match the embedding model's output.
    """
    index, metas = load_index(env)
    qv = embed([q])  # shape (1, d) float32

    # Defensive: an index built with a different embedding model would
    # otherwise fail deep inside FAISS with a cryptic error.
    if hasattr(index, "d") and index.d != qv.shape[1]:
        raise RuntimeError(f"FAISS index dim {getattr(index, 'd', '?')} "
                           f"!= embedding dim {qv.shape[1]}")

    scores, idxs = index.search(qv, top_k)  # both shaped (1, top_k)

    active = filters or {}
    f_geo = active.get("geo")
    # Hoist the set conversion out of the loop; empty/None means "no filter".
    f_cats = set(active.get("categories") or ())

    results = []
    for score, idx in zip(scores[0], idxs[0]):
        if idx == -1:
            # FAISS pads with -1 when the index holds fewer than top_k vectors.
            continue
        m = dict(metas[idx])  # copy so we don't mutate the cached list
        if f_geo and m.get("geo") not in f_geo:
            continue
        if f_cats and f_cats.isdisjoint(m.get("categories", [])):
            continue
        m["score"] = float(score)
        results.append(m)
    return results
|