import os
import pickle
import numpy as np
from typing import List, Dict, Any, Optional
import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder


class VectorStore:
    def __init__(self, embedding_dir: str = "data/embeddings",
                 model_name: str = "BAAI/bge-small-en-v1.5",
                 reranker_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.embedding_dir = embedding_dir
        self.index = None
        self.chunk_ids = []
        self.chunks = {}

        # Load embedding model
        self.model = SentenceTransformer(model_name)

        # Load reranker model
        self.reranker = CrossEncoder(reranker_name)

        # Load or create index
        self.load_or_create_index()

    def load_or_create_index(self) -> None:
        """Load existing index or create a new one."""
        index_path = os.path.join(self.embedding_dir, 'faiss_index.pkl')

        if os.path.exists(index_path):
            # Load existing index
            with open(index_path, 'rb') as f:
                data = pickle.load(f)
                self.index = data['index']
                self.chunk_ids = data['chunk_ids']
                self.chunks = data['chunks']
            print(f"Loaded existing index with {len(self.chunk_ids)} chunks")
        else:
            # Create new index
            embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')
            if os.path.exists(embeddings_path):
                self.create_index()
            else:
                print("No embeddings found. Please run the chunker first.")

    def create_index(self) -> None:
        """Create FAISS index from embeddings."""
        embeddings_path = os.path.join(self.embedding_dir, 'embeddings.pkl')

        with open(embeddings_path, 'rb') as f:
            embedding_map = pickle.load(f)

        # Extract embeddings and chunk IDs
        chunk_ids = list(embedding_map.keys())
        embeddings = np.array([embedding_map[chunk_id]['embedding'] for chunk_id in chunk_ids])
        chunks = {chunk_id: embedding_map[chunk_id]['chunk'] for chunk_id in chunk_ids}

        # Create FAISS index
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings.astype(np.float32))

        # Save index and metadata
        self.index = index
        self.chunk_ids = chunk_ids
        self.chunks = chunks

        # Save to disk
        with open(os.path.join(self.embedding_dir, 'faiss_index.pkl'), 'wb') as f:
            pickle.dump({
                'index': index,
                'chunk_ids': chunk_ids,
                'chunks': chunks
            }, f)

        print(f"Created index with {len(chunk_ids)} chunks")

    def search(self, query: str, k: int = 5,
               filter_categories: Optional[List[str]] = None,
               rerank: bool = True) -> List[Dict[str, Any]]:
        """Search for relevant chunks."""
        if self.index is None:
            print("No index available. Please create an index first.")
            return []

        # Create query embedding
        query_embedding = self.model.encode([query])[0]

        # Search index
        D, I = self.index.search(np.array([query_embedding]).astype(np.float32),
                                 min(k * 2, len(self.chunk_ids)))

        # Get results
        results = []
        for i, idx in enumerate(I[0]):
            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]

            # Apply category filter if specified
            if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                continue

            result = {
                'chunk_id': chunk_id,
                'score': float(D[0][i]),
                'chunk': chunk
            }
            results.append(result)

        # Rerank results if requested
        if rerank and results:
            # Prepare pairs for reranking
            pairs = [(query, result['chunk']['content']) for result in results]

            # Get reranking scores
            rerank_scores = self.reranker.predict(pairs)

            # Update scores and sort
            for i, score in enumerate(rerank_scores):
                results[i]['rerank_score'] = float(score)

            # Sort by rerank score
            results = sorted(results, key=lambda x: x['rerank_score'], reverse=True)

        # Limit to k results
        results = results[:k]

        return results

    def hybrid_search(self, query: str, k: int = 5,
                      filter_categories: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Combine dense vector search with BM25-style keyword matching."""
        # Get vector search results
        vector_results = self.search(query, k=k, filter_categories=filter_categories, rerank=False)

        # Simple keyword matching (simulating BM25)
        keywords = query.lower().split()

        # Score all chunks by keyword presence
        keyword_scores = {}
        for chunk_id, chunk in self.chunks.items():
            # Apply category filter if specified
            if filter_categories and not any(cat in chunk.get('categories', []) for cat in filter_categories):
                continue

            content = (chunk['title'] + " " + chunk['content']).lower()

            # Count keyword matches
            keyword_scores[chunk_id] = sum(content.count(keyword) for keyword in keywords)

        # Get top keyword matches
        keyword_results = sorted(
            [{'chunk_id': chunk_id, 'score': score, 'chunk': self.chunks[chunk_id]}
             for chunk_id, score in keyword_scores.items() if score > 0],
            key=lambda x: x['score'],
            reverse=True
        )[:k]

        # Combine results (remove duplicates)
        seen_ids = set()
        combined_results = []

        # Add vector results first
        for result in vector_results:
            combined_results.append(result)
            seen_ids.add(result['chunk_id'])

        # Add keyword results if not already added
        for result in keyword_results:
            if result['chunk_id'] not in seen_ids:
                combined_results.append(result)
                seen_ids.add(result['chunk_id'])

        # Limit to k results
        combined_results = combined_results[:k]

        # Rerank final results
        if combined_results:
            # Prepare pairs for reranking
            pairs = [(query, result['chunk']['content']) for result in combined_results]

            # Get reranking scores
            rerank_scores = self.reranker.predict(pairs)

            # Update scores and sort
            for i, score in enumerate(rerank_scores):
                combined_results[i]['rerank_score'] = float(score)

            # Sort by rerank score
            combined_results = sorted(combined_results, key=lambda x: x['rerank_score'], reverse=True)

        return combined_results


# Example usage
if __name__ == "__main__":
    vector_store = VectorStore()
    results = vector_store.hybrid_search("How do I apply for OPT?")

    print(f"Found {len(results)} results")
    for i, result in enumerate(results[:3]):
        print(f"Result {i+1}: {result['chunk']['title']}")
        print(f"Score: {result.get('rerank_score', result['score'])}")
        print(f"Content: {result['chunk']['content'][:100]}...")
        print()
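

# ---------------------------------------------------------------------------
# Hypothetical helper (an assumption, not part of the class above): a minimal
# sketch of how the upstream chunker step could produce
# data/embeddings/embeddings.pkl in the layout that create_index() expects --
# a dict keyed by chunk ID, where each value holds an 'embedding' vector and
# the 'chunk' metadata used elsewhere in this module ('title', 'content',
# 'categories'). The function name and the chunk fields are illustrative only.
def build_embeddings_file(chunks: List[Dict[str, Any]],
                          embedding_dir: str = "data/embeddings",
                          model_name: str = "BAAI/bge-small-en-v1.5") -> None:
    model = SentenceTransformer(model_name)

    # Encode every chunk's content in one batch; create_index() infers the
    # embedding dimension from this array, so any SentenceTransformer works.
    vectors = model.encode([chunk['content'] for chunk in chunks])

    # Build the {chunk_id: {'embedding': ..., 'chunk': ...}} map that
    # create_index() reads back from disk.
    embedding_map = {
        chunk['chunk_id']: {'embedding': vectors[i], 'chunk': chunk}
        for i, chunk in enumerate(chunks)
    }

    os.makedirs(embedding_dir, exist_ok=True)
    with open(os.path.join(embedding_dir, 'embeddings.pkl'), 'wb') as f:
        pickle.dump(embedding_map, f)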