"""Dense vector retriever with automatic FAISS index construction.""" | |
from __future__ import annotations | |
import json | |
import logging | |
import os | |
from pathlib import Path | |
from typing import List, Optional, Sequence, Union | |
import faiss # type: ignore | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
from .base import Context, Retriever | |
logger = logging.getLogger(__name__) | |
class DenseRetriever(Retriever):
    """Sentence-Transformers + FAISS ANN search.

    * If `faiss_index` does **not** exist, it is built from `doc_store`.
    * Embedding model (and its cache location) are configurable.

    `doc_store` is a JSON-lines file whose records carry a ``"text"`` field.
    Every failure during construction is downgraded to a warning; a broken
    retriever simply returns no hits from :meth:`retrieve`.
    """

    def __init__(
        self,
        faiss_index: Union[str, Path],
        *,
        doc_store: Union[str, Path],
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        embedder_cache: Optional[Union[str, Path]] = None,
        device: str = "cpu",
    ):
        self.faiss_index = Path(faiss_index)
        self.doc_store = Path(doc_store)
        # Each stage is best-effort: a failure leaves the corresponding
        # attribute as None/empty and retrieve() degrades to no hits.
        self.embedder = self._load_embedder(model_name, embedder_cache, device)
        if not self.faiss_index.exists():
            self._maybe_build_index()
        self.index = self._load_index()
        # Keep doc texts in memory so hits can be returned with their text.
        self._texts: List[str] = self._load_texts()

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
        """Return up to ``top_k`` contexts ranked by similarity to ``query``.

        Scores are "higher is better": for an L2 index the raw distance is
        negated so callers can sort uniformly. Returns an empty list when
        the index or embedder failed to load, or when search itself fails.
        """
        if self.index is None or self.embedder is None:
            return []
        try:
            vec = self._embed(query)
            vec = np.asarray(vec, dtype="float32")[None, :]
            dists, idxs = self.index.search(vec, top_k)
            dists, idxs = dists[0], idxs[0]
        except Exception as e:
            logger.warning("DenseRetriever retrieval failed (%s); returning no hits.", e)
            return []
        results: List[Context] = []
        for i, score in zip(idxs, dists):
            # FAISS pads with -1 when fewer than top_k vectors exist.
            if i == -1:
                continue
            # If metric is L2, higher distance means worse; invert
            if self.index.metric_type == faiss.METRIC_L2:
                score = -score
            # Positional lookup: _texts order matches index insertion order.
            text = self._texts[i] if (0 <= i < len(self._texts)) else ""
            results.append(Context(id=str(i), text=text, score=float(score)))
        results.sort(key=lambda c: c.score, reverse=True)
        return results

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _load_embedder(
        model_name: str,
        embedder_cache: Optional[Union[str, Path]],
        device: str,
    ) -> Optional[SentenceTransformer]:
        """Instantiate the SentenceTransformer, or return None on failure."""
        try:
            embedder = SentenceTransformer(
                model_name,
                device=device,
                cache_folder=str(embedder_cache) if embedder_cache else None,
            )
        except Exception as e:
            logger.warning(
                "Failed to load SentenceTransformer '%s' (%s). "
                "DenseRetriever.retrieve() will return no hits.",
                model_name,
                e,
            )
            return None
        logger.info("Embedder '%s' ready (device=%s)", model_name, device)
        return embedder

    def _maybe_build_index(self) -> None:
        """Build the FAISS index from doc_store if an embedder is available."""
        if self.embedder is None:
            logger.warning("Embedder is None; skipping FAISS index build.")
            return
        try:
            logger.info("FAISS index %s missing – building ...", self.faiss_index)
            self._build_index()
        except Exception as e:
            logger.warning(
                "Failed to build FAISS index: %s. "
                "DenseRetriever.retrieve() will return no hits.",
                e,
            )

    def _load_index(self):
        """Load the FAISS index from disk, or return None on failure."""
        try:
            index = faiss.read_index(str(self.faiss_index))
        except Exception as e:
            logger.warning(
                "Unable to load FAISS index (%s). DenseRetriever.retrieve() will return no hits.",
                e,
            )
            return None
        logger.info("Loaded FAISS index with %d vectors", index.ntotal)
        return index

    def _load_texts(self) -> List[str]:
        """Read one text per JSONL record; empty list on any failure."""
        texts: List[str] = []
        try:
            with self.doc_store.open(encoding="utf-8") as f:
                for line in f:
                    texts.append(json.loads(line).get("text", ""))
        except Exception as e:
            logger.warning(
                "Failed to load doc_store texts (%s). Retrieved contexts will have empty text.",
                e,
            )
            return []
        return texts

    def _embed(self, text: str) -> Sequence[float]:
        """Encode ``text`` into a unit-normalized embedding vector."""
        return self.embedder.encode(text, normalize_embeddings=True).tolist()

    def _build_index(self):
        """Read all texts, embed them, and write a FAISS IP index.

        Raises:
            ValueError: if the doc store contains no documents (previously
                this surfaced as an opaque IndexError on ``embs.shape[1]``).
        """
        logger.info("Reading documents from %s", self.doc_store)
        texts: List[str] = []
        with self.doc_store.open(encoding="utf-8") as f:
            for line in f:
                texts.append(json.loads(line)["text"])
        if not texts:
            raise ValueError(f"doc_store {self.doc_store} contains no documents")
        logger.info("Embedding %d documents ...", len(texts))
        embs = self.embedder.encode(
            texts,
            batch_size=128,
            show_progress_bar=True,
            normalize_embeddings=True,
        ).astype("float32")
        logger.info("Creating FAISS index (Inner-Product)")
        # Vectors are normalized above, so inner product == cosine similarity.
        index = faiss.IndexFlatIP(embs.shape[1])
        index.add(embs)
        faiss.write_index(index, str(self.faiss_index))
        logger.info("Saved FAISS index to %s", self.faiss_index)