# Commit f868144: resolved errors shown by tests.
"""Dense vector retriever with automatic FAISS index construction."""
from __future__ import annotations
import json
import logging
import os
from pathlib import Path
from typing import List, Optional, Sequence, Union
import faiss # type: ignore
import numpy as np
from sentence_transformers import SentenceTransformer
from .base import Context, Retriever
logger = logging.getLogger(__name__)
class DenseRetriever(Retriever):
    """Sentence-Transformers + FAISS ANN search.

    * If `faiss_index` does **not** exist, it is built from `doc_store`.
    * Embedding model (and its cache location) are configurable.

    The doc store is a UTF-8 JSONL file with one ``{"id": ..., "text": ...}``
    object per line. FAISS row *i* corresponds to line *i* of the doc store
    (``IndexFlatIP`` assigns sequential positional ids), which is why
    ``retrieve`` can look hits up in ``self._texts`` by row number.

    Every failure mode (model load, index build, index read, doc-store
    parse) degrades gracefully: ``retrieve()`` then returns no hits
    instead of raising.
    """

    def __init__(
        self,
        faiss_index: Union[str, Path],
        *,
        doc_store: Union[str, Path],
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        embedder_cache: Optional[Union[str, Path]] = None,
        device: str = "cpu",
    ):
        """Load (or build) the FAISS index and cache doc texts in memory.

        Args:
            faiss_index: Path to a serialized FAISS index; built from
                ``doc_store`` if it does not exist yet.
            doc_store: Path to the JSONL document store.
            model_name: Sentence-Transformers model identifier.
            embedder_cache: Optional cache directory for model downloads.
            device: Torch device string for the embedder (e.g. ``"cpu"``).
        """
        self.faiss_index = Path(faiss_index)
        self.doc_store = Path(doc_store)

        # ------------------------------------------------------------------
        # Sentence-Transformers embedder
        # ------------------------------------------------------------------
        try:
            self.embedder = SentenceTransformer(
                model_name,
                device=device,
                cache_folder=str(embedder_cache) if embedder_cache else None,
            )
            logger.info("Embedder '%s' ready (device=%s)", model_name, device)
        except Exception as e:
            # Deliberately broad: a missing model/network error must not
            # crash construction — retrieve() just returns no hits.
            logger.warning(
                "Failed to load SentenceTransformer '%s' (%s). "
                "DenseRetriever.retrieve() will return no hits.",
                model_name,
                e,
            )
            self.embedder = None

        # ------------------------------------------------------------------
        # Build FAISS index if absent
        # ------------------------------------------------------------------
        if not self.faiss_index.exists():
            if self.embedder is not None:
                try:
                    logger.info("FAISS index %s missing – building ...", self.faiss_index)
                    self._build_index()
                except Exception as e:
                    logger.warning(
                        "Failed to build FAISS index: %s. "
                        "DenseRetriever.retrieve() will return no hits.",
                        e,
                    )
            else:
                logger.warning("Embedder is None; skipping FAISS index build.")

        # Attempt to load the index (may have just been built above).
        try:
            self.index = faiss.read_index(str(self.faiss_index))
            logger.info("Loaded FAISS index with %d vectors", self.index.ntotal)
        except Exception as e:
            logger.warning(
                "Unable to load FAISS index (%s). DenseRetriever.retrieve() will return no hits.",
                e,
            )
            self.index = None

        # Keep doc texts in memory so retrieve() can resolve row ids to
        # text without re-reading the store. Missing "text" keys become "".
        self._texts: List[str] = []
        try:
            self._texts = [obj.get("text", "") for obj in self._iter_docs()]
        except Exception as e:
            logger.warning(
                "Failed to load doc_store texts (%s). Retrieved contexts will have empty text.",
                e,
            )
            self._texts = []

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
        """Return up to ``top_k`` contexts ranked by similarity to ``query``.

        Returns an empty list when the embedder or index is unavailable,
        or when the search itself fails.
        """
        if self.index is None or self.embedder is None:
            return []
        try:
            vec = self._embed(query)
            vec = np.asarray(vec, dtype="float32")[None, :]
            dists, idxs = self.index.search(vec, top_k)
            dists, idxs = dists[0], idxs[0]
        except Exception as e:
            logger.warning("DenseRetriever retrieval failed (%s); returning no hits.", e)
            return []
        results: List[Context] = []
        for i, score in zip(idxs, dists):
            if i == -1:  # FAISS pads with -1 when fewer than top_k hits exist
                continue
            # If metric is L2, higher distance means worse; invert so that
            # larger score is always better regardless of index metric.
            if self.index.metric_type == faiss.METRIC_L2:
                score = -score
            # Row id i maps to line i of the doc store (see class docstring).
            text = self._texts[i] if (0 <= i < len(self._texts)) else ""
            results.append(Context(id=str(i), text=text, score=float(score)))
        results.sort(key=lambda c: c.score, reverse=True)
        return results

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #
    def _iter_docs(self):
        """Yield one parsed JSON object per line of the doc store (UTF-8 JSONL)."""
        with self.doc_store.open(encoding="utf-8") as f:
            for line in f:
                yield json.loads(line)

    def _embed(self, text: str) -> Sequence[float]:
        """Embed a single string as an L2-normalized vector (list of floats)."""
        return self.embedder.encode(text, normalize_embeddings=True).tolist()

    def _build_index(self):
        """Read all texts, embed them, and write a FAISS IP index."""
        logger.info("Reading documents from %s", self.doc_store)
        texts: List[str] = []
        for obj in self._iter_docs():
            # Validate the id eagerly so a malformed store fails the build
            # (FAISS ids are positional; the value itself is not stored).
            int(obj["id"])
            texts.append(obj["text"])
        logger.info("Embedding %d documents ...", len(texts))
        embs = self.embedder.encode(
            texts,
            batch_size=128,
            show_progress_bar=True,
            normalize_embeddings=True,
        ).astype("float32")
        logger.info("Creating FAISS index (Inner-Product)")
        # Embeddings are normalized, so inner product == cosine similarity.
        index = faiss.IndexFlatIP(embs.shape[1])
        index.add(embs)
        faiss.write_index(index, str(self.faiss_index))
        logger.info("Saved FAISS index to %s", self.faiss_index)