Spaces:
Running
Running
""" | |
Unified Retriever for Phase 2 Architecture Migration. | |
This component consolidates FAISSVectorStore and HybridRetriever functionality | |
into a single, more efficient Retriever component. It eliminates the abstraction | |
layer between vector storage and retrieval while maintaining all existing capabilities. | |
""" | |
import sys | |
import logging | |
from pathlib import Path | |
from typing import List, Dict, Any, Optional, Union | |
import numpy as np | |
# Add project root to path for imports | |
project_root = Path(__file__).parent.parent.parent.parent.parent | |
sys.path.append(str(project_root)) | |
from src.core.interfaces import Document, RetrievalResult, Retriever, Embedder | |
from shared_utils.retrieval.hybrid_search import HybridRetriever as OriginalHybridRetriever | |
# Import FAISS functionality directly | |
import faiss | |
logger = logging.getLogger(__name__) | |
class UnifiedRetriever(Retriever):
    """
    Unified retriever combining vector storage and hybrid search capabilities.

    This component merges the functionality of FAISSVectorStore and HybridRetriever
    into a single efficient component that provides:
    - Dense semantic search with FAISS vector storage
    - Sparse BM25 keyword matching (delegated to the wrapped hybrid retriever)
    - Reciprocal Rank Fusion (RRF) for result combination
    - Direct component access without abstraction layers

    Indexing semantics: each call to ``index_documents`` REPLACES the current
    corpus (documents, FAISS vectors, and sparse index) rather than appending
    to it, keeping FAISS vector positions aligned with ``self.documents``.

    Example:
        retriever = UnifiedRetriever(
            embedder=sentence_embedder,
            dense_weight=0.7,
            embedding_dim=384
        )
        retriever.index_documents(documents)
        results = retriever.retrieve("What is RISC-V?", k=5)
    """

    def __init__(
        self,
        embedder: Embedder,
        dense_weight: float = 0.7,
        embedding_dim: int = 384,
        index_type: str = "IndexFlatIP",
        normalize_embeddings: bool = True,
        metric: str = "cosine",
        embedding_model: str = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
        use_mps: bool = True,
        bm25_k1: float = 1.2,
        bm25_b: float = 0.75,
        rrf_k: int = 10
    ):
        """
        Initialize the unified retriever.

        Args:
            embedder: Embedder for query encoding
            dense_weight: Weight for semantic similarity in fusion (default: 0.7)
            embedding_dim: Dimension of embeddings (default: 384)
            index_type: FAISS index type (default: "IndexFlatIP")
            normalize_embeddings: Whether to normalize embeddings (default: True)
            metric: Distance metric ("cosine" or "euclidean", default: "cosine")
            embedding_model: Sentence transformer model name
            use_mps: Use Apple Silicon MPS acceleration (default: True)
            bm25_k1: BM25 term frequency saturation parameter (default: 1.2)
            bm25_b: BM25 document length normalization parameter (default: 0.75)
            rrf_k: Reciprocal Rank Fusion constant (default: 10)
        """
        self.embedder = embedder
        self.dense_weight = dense_weight
        self.sparse_weight = 1.0 - dense_weight

        # FAISS vector store configuration
        self.embedding_dim = embedding_dim
        self.index_type = index_type
        self.normalize_embeddings = normalize_embeddings
        self.metric = metric

        # Remember the hybrid-search parameters ourselves so that
        # get_configuration() and clear_index() can report/rebuild them
        # faithfully even when the wrapped retriever does not expose them
        # as attributes (the old getattr(..., default) fallback silently
        # dropped custom values).
        self._embedding_model = embedding_model
        self._use_mps = use_mps
        self._bm25_k1 = bm25_k1
        self._bm25_b = bm25_b
        self._rrf_k = rrf_k

        # FAISS index and document bookkeeping.  Vector position i in the
        # index always corresponds to self.documents[i].
        self.index: Optional[faiss.Index] = None
        self.documents: List[Document] = []
        self.doc_id_to_index: Dict[str, int] = {}
        self._next_doc_id = 0

        # Wrapped retriever providing sparse BM25 + RRF fusion
        self.hybrid_retriever = OriginalHybridRetriever(
            dense_weight=dense_weight,
            embedding_model=embedding_model,
            use_mps=use_mps,
            bm25_k1=bm25_k1,
            bm25_b=bm25_b,
            rrf_k=rrf_k
        )

        # Chunk dicts handed to the hybrid retriever, cached for result mapping
        self._chunks_cache: List[Dict] = []

        logger.info(f"UnifiedRetriever initialized with dense_weight={dense_weight}")

    def retrieve(self, query: str, k: int = 5) -> List[RetrievalResult]:
        """
        Retrieve relevant documents using unified hybrid search.

        Combines dense semantic search (FAISS) and sparse BM25 retrieval
        using Reciprocal Rank Fusion, delegating to the wrapped hybrid
        retriever and mapping chunk indices back to Document objects.

        Args:
            query: Search query string
            k: Number of results to return (default: 5)

        Returns:
            List of retrieval results sorted by relevance score

        Raises:
            ValueError: If k <= 0 or query is empty
            RuntimeError: If no documents have been indexed, or the
                underlying search fails (original error chained as cause)
        """
        if k <= 0:
            raise ValueError("k must be positive")
        if not query.strip():
            raise ValueError("Query cannot be empty")
        if not self._chunks_cache or self.index is None:
            raise RuntimeError("No documents have been indexed")

        try:
            # The hybrid retriever handles both dense and sparse search
            search_results = self.hybrid_retriever.search(
                query=query,
                top_k=k
            )

            # Each result is a tuple: (chunk_index, rrf_score, chunk_dict).
            # chunk_index equals the document's position in self.documents
            # (see index_documents), so it can be used directly.
            retrieval_results = []
            for chunk_idx, score, _chunk_dict in search_results:
                if chunk_idx < len(self.documents):
                    retrieval_results.append(RetrievalResult(
                        document=self.documents[chunk_idx],
                        score=float(score),
                        retrieval_method="unified_hybrid_rrf"
                    ))
            return retrieval_results

        except Exception as e:
            logger.error(f"Unified retrieval failed: {str(e)}")
            raise RuntimeError(f"Unified retrieval failed: {str(e)}") from e

    def index_documents(self, documents: List[Document]) -> None:
        """
        Index documents for both dense and sparse retrieval.

        This REPLACES any previously indexed corpus: the document list, the
        FAISS index, and the sparse index are all rebuilt from `documents`.
        (Previously the FAISS index was appended to while self.documents was
        replaced, leaving stale vectors whose positions no longer matched
        the document list.)

        Args:
            documents: List of documents to index

        Raises:
            ValueError: If documents list is empty, a document is missing its
                embedding, or an embedding has the wrong dimension
        """
        if not documents:
            raise ValueError("Cannot index empty document list")

        # Validate embeddings up front so we fail before mutating state
        for i, doc in enumerate(documents):
            if doc.embedding is None:
                raise ValueError(f"Document {i} is missing embedding")
            if len(doc.embedding) != self.embedding_dim:
                raise ValueError(
                    f"Document {i} embedding dimension {len(doc.embedding)} "
                    f"doesn't match expected {self.embedding_dim}"
                )

        # Replace the stored corpus and rebuild the FAISS index from scratch
        # so vector position i always corresponds to self.documents[i].
        self.documents = documents.copy()
        self.doc_id_to_index.clear()
        self._initialize_faiss_index()
        self._add_to_faiss_index(documents)

        # Prepare chunk dicts for the sparse/hybrid side
        chunks = []
        for i, doc in enumerate(documents):
            # Assign a globally unique doc_id (monotonically increasing
            # across index_documents calls)
            doc_id = str(self._next_doc_id)
            self._next_doc_id += 1
            if 'doc_id' not in doc.metadata:
                doc.metadata['doc_id'] = doc_id
            self.doc_id_to_index[doc_id] = i

            chunks.append({
                "text": doc.content,
                "chunk_id": i,
                # Flatten document metadata into the chunk
                **doc.metadata
            })

        # Cache chunks for result mapping and index them for hybrid search
        self._chunks_cache = chunks
        self.hybrid_retriever.index_documents(chunks)

        logger.info(f"Indexed {len(documents)} documents in unified retriever")

    def get_retrieval_stats(self) -> Dict[str, Any]:
        """
        Get comprehensive statistics about the unified retrieval system.

        Returns:
            Dictionary with retrieval statistics and configuration, including
            FAISS index state and (when available) the wrapped hybrid
            retriever's own stats prefixed with "hybrid_".
        """
        has_index = self.index is not None
        stats = {
            "component_type": "unified_retriever",
            "indexed_documents": len(self.documents),
            "dense_weight": self.dense_weight,
            "sparse_weight": self.sparse_weight,
            "retrieval_type": "unified_hybrid_dense_sparse",
            "embedding_dim": self.embedding_dim,
            "index_type": self.index_type,
            "normalize_embeddings": self.normalize_embeddings,
            "metric": self.metric,
            "faiss_total_vectors": self.index.ntotal if has_index else 0,
            "faiss_is_trained": self.index.is_trained if has_index else False
        }

        if has_index:
            # Rough size estimate: ntotal vectors of float32 components
            stats["faiss_index_size_bytes"] = self.index.ntotal * self.embedding_dim * 4

        try:
            original_stats = self.hybrid_retriever.get_retrieval_stats()
            stats.update({"hybrid_" + k: v for k, v in original_stats.items()})
        except Exception:
            # Wrapped retriever might not implement get_retrieval_stats
            pass

        return stats

    def supports_batch_queries(self) -> bool:
        """
        Check if this retriever supports batch query processing.

        Returns:
            False, as the current implementation processes queries individually
        """
        return False

    def get_configuration(self) -> Dict[str, Any]:
        """
        Get the current configuration of the unified retriever.

        Returns:
            Dictionary with the configuration parameters this instance was
            constructed with (hybrid-search parameters are the stored
            constructor values, not probed from the wrapped retriever).
        """
        return {
            "dense_weight": self.dense_weight,
            "sparse_weight": self.sparse_weight,
            "embedding_dim": self.embedding_dim,
            "index_type": self.index_type,
            "normalize_embeddings": self.normalize_embeddings,
            "metric": self.metric,
            "bm25_k1": self._bm25_k1,
            "bm25_b": self._bm25_b,
            "rrf_k": self._rrf_k,
            "embedding_model": self._embedding_model,
            "use_mps": self._use_mps
        }

    def clear_index(self) -> None:
        """
        Clear all indexed documents and reset the retriever.

        Resets both FAISS and hybrid search components; the hybrid retriever
        is rebuilt with the exact constructor parameters (previously custom
        BM25/RRF values could be silently replaced by defaults when the
        wrapped retriever did not expose them as attributes).
        """
        # Clear FAISS components
        self.index = None
        self.documents.clear()
        self.doc_id_to_index.clear()
        self._next_doc_id = 0

        # Clear hybrid search components and rebuild the wrapped retriever
        self._chunks_cache.clear()
        self.hybrid_retriever = OriginalHybridRetriever(
            dense_weight=self.dense_weight,
            embedding_model=self._embedding_model,
            use_mps=self._use_mps,
            bm25_k1=self._bm25_k1,
            bm25_b=self._bm25_b,
            rrf_k=self._rrf_k
        )

        logger.info("Cleared all documents from unified retriever")

    def get_document_count(self) -> int:
        """Get the number of documents in the retriever."""
        return len(self.documents)

    def get_faiss_info(self) -> Dict[str, Any]:
        """
        Get information about the FAISS index.

        Returns:
            Dictionary with FAISS index configuration and state
        """
        has_index = self.index is not None
        info = {
            "index_type": self.index_type,
            "embedding_dim": self.embedding_dim,
            "normalize_embeddings": self.normalize_embeddings,
            "metric": self.metric,
            "document_count": len(self.documents),
            "is_trained": self.index.is_trained if has_index else False,
            "total_vectors": self.index.ntotal if has_index else 0
        }
        if has_index:
            # Rough size estimate assuming float32 storage
            info["index_size_bytes"] = self.index.ntotal * self.embedding_dim * 4
        return info

    def _initialize_faiss_index(self) -> None:
        """(Re)initialize the FAISS index based on configuration.

        Raises:
            ValueError: If self.index_type is not one of the supported types.
        """
        if self.index_type == "IndexFlatIP":
            # Inner product (cosine similarity with normalized embeddings)
            self.index = faiss.IndexFlatIP(self.embedding_dim)
        elif self.index_type == "IndexFlatL2":
            # L2 distance (Euclidean)
            self.index = faiss.IndexFlatL2(self.embedding_dim)
        elif self.index_type == "IndexIVFFlat":
            # IVF with flat quantizer (requires training before adding)
            nlist = min(100, max(10, int(np.sqrt(1000))))  # heuristic cluster count
            quantizer = faiss.IndexFlatL2(self.embedding_dim)
            self.index = faiss.IndexIVFFlat(quantizer, self.embedding_dim, nlist)
        else:
            raise ValueError(f"Unsupported FAISS index type: {self.index_type}")

        logger.info(f"Initialized FAISS index: {self.index_type}")

    def _add_to_faiss_index(self, documents: List[Document]) -> None:
        """Add the documents' embeddings to the FAISS index.

        Embeddings are stacked into a float32 matrix (FAISS requirement) and
        L2-normalized first when self.normalize_embeddings is set.
        """
        embeddings = np.array([doc.embedding for doc in documents], dtype=np.float32)

        if self.normalize_embeddings:
            embeddings = self._normalize_embeddings(embeddings)

        self.index.add(embeddings)
        logger.debug(f"Added {len(documents)} documents to FAISS index")

    def _normalize_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
        """
        Normalize embeddings row-wise to unit L2 norm (for cosine similarity).

        Args:
            embeddings: 2-D array of embeddings, one row per document

        Returns:
            Normalized embeddings; zero vectors are passed through unchanged
        """
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        # Avoid division by zero for all-zero embedding rows
        norms = np.where(norms == 0, 1, norms)
        return embeddings / norms