File size: 2,555 Bytes
598f5cb
 
8a31d42
598f5cb
 
8a31d42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
598f5cb
8a31d42
598f5cb
 
 
 
 
8a31d42
 
598f5cb
 
8a31d42
 
 
 
 
 
 
 
 
 
 
 
 
598f5cb
 
8a31d42
598f5cb
8a31d42
 
 
 
 
 
 
 
 
598f5cb
8a31d42
 
 
598f5cb
8a31d42
 
 
 
 
 
 
598f5cb
 
 
8a31d42
598f5cb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import json
from pathlib import Path
from typing import List, Dict, Optional
import numpy as np

from sentence_transformers import SentenceTransformer

# ---------- Global embedder (loaded once, CPU-safe) ----------
# Process-wide cache for the SentenceTransformer model. Starts as None and
# is populated lazily by _get_embedder() on first use, so importing this
# module never triggers a model download/load.
_EMBEDDER: Optional[SentenceTransformer] = None

def _get_embedder() -> SentenceTransformer:
    """Return the process-wide sentence embedder, creating it on first call.

    The model is constructed lazily and cached in the module-level
    ``_EMBEDDER`` so repeated calls share a single instance.
    """
    global _EMBEDDER
    if _EMBEDDER is not None:
        return _EMBEDDER
    # Explicit device="cpu" avoids any device_map/meta init paths.
    # Use the canonical model id to avoid redirect surprises.
    model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        device="cpu"
    )
    # Optional: shorten for speed on Spaces; keep accuracy reasonable
    model.max_seq_length = 256
    _EMBEDDER = model
    return _EMBEDDER

def load_index(env: Dict):
    """Load the FAISS index and its parallel metadata list.

    Args:
        env: Mapping that must contain "INDEX_DIR", the directory holding
            "faiss.index" and "meta.json" (both written together by ingest).

    Returns:
        Tuple of (faiss index, list of metadata dicts) where metas[i]
        describes the vector with FAISS id i.

    Raises:
        RuntimeError: If either artifact is missing. Previously a missing
            meta.json slipped past the index-only check and surfaced as a
            raw FileNotFoundError; both files are now validated up front
            with the same actionable message.
    """
    import faiss  # local import keeps faiss optional until search is used

    index_dir = Path(env["INDEX_DIR"])
    index_path = index_dir / "faiss.index"
    meta_path = index_dir / "meta.json"
    if not index_path.exists() or not meta_path.exists():
        raise RuntimeError("Index not found. Run ingest first.")
    index = faiss.read_index(str(index_path))
    with open(meta_path, "r", encoding="utf-8") as f:
        metas = json.load(f)
    return index, metas

def embed(texts: List[str]) -> np.ndarray:
    """Encode *texts* into unit-normalized float32 vectors for FAISS.

    Args:
        texts: Strings to embed.

    Returns:
        Array of shape (len(texts), d), dtype float32, rows L2-normalized
        (so inner product equals cosine similarity).
    """
    model = _get_embedder()
    vectors = model.encode(
        texts,
        batch_size=32,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    # FAISS expects float32; asarray skips the copy when already float32.
    return np.asarray(vectors, dtype=np.float32)

def search(q: str, env: Dict, top_k: int = 15, filters: Optional[Dict] = None) -> List[Dict]:
    """Run a semantic search over the FAISS index and filter the hits.

    Args:
        q: Query string.
        env: Mapping with "INDEX_DIR" (see load_index).
        top_k: Number of nearest neighbors retrieved BEFORE filtering.
        filters: Optional dict; "geo" is a container of allowed geo values,
            "categories" an iterable — a hit passes if it shares at least
            one category.

    Returns:
        Metadata dicts (copies) each annotated with a float "score", in
        FAISS result order. NOTE: filters are applied post-retrieval, so
        fewer than top_k results may come back when filters are active.

    Raises:
        RuntimeError: If the index is missing (via load_index) or its
            dimensionality does not match the embedding dimensionality.
    """
    # NOTE: the previous version imported faiss here but never used it;
    # load_index performs its own import.
    index, metas = load_index(env)

    qv = embed([q])  # shape (1, d) float32
    # Defensive: ensure index dim matches query dim
    if hasattr(index, "d") and index.d != qv.shape[1]:
        raise RuntimeError(f"FAISS index dim {getattr(index, 'd', '?')} "
                           f"!= embedding dim {qv.shape[1]}")

    scores, idxs = index.search(qv, top_k)  # scores shape (1, k), idxs shape (1, k)

    f = filters or {}
    f_geo = f.get("geo")
    # Hoist the set conversion out of the loop (it was rebuilt per hit);
    # an empty "categories" filter still means "no category filtering".
    f_cats = set(f["categories"]) if f.get("categories") else None

    results: List[Dict] = []
    for score, idx in zip(scores[0], idxs[0]):
        if idx == -1:
            continue  # FAISS pads with -1 when fewer than k vectors exist
        m = dict(metas[idx])  # copy so we don’t mutate the cached list
        if f_geo and m.get("geo") not in f_geo:
            continue
        if f_cats is not None and f_cats.isdisjoint(m.get("categories", [])):
            continue
        m["score"] = float(score)
        results.append(m)

    return results