#!/usr/bin/env python3
"""
Vector Store Manager for Maternal Health RAG Chatbot
Uses FAISS with the optimal all-MiniLM-L6-v2 embedding model
"""

import json
import faiss
from pathlib import Path
from typing import List, Dict, Any, Optional
import logging
from sentence_transformers import SentenceTransformer
import time
from dataclasses import dataclass

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class SearchResult:
    """Container for search results"""
    content: str
    score: float
    metadata: Dict[str, Any]
    chunk_index: int
    source_document: str
    chunk_type: str
    clinical_importance: float


class MaternalHealthVectorStore:
    """Vector store for maternal health guidelines with clinical context filtering"""

    def __init__(self,
                 vector_store_dir: str = "vector_store",
                 embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
                 chunks_dir: str = "comprehensive_chunks"):
        self.vector_store_dir = Path(vector_store_dir)
        self.vector_store_dir.mkdir(exist_ok=True)
        self.chunks_dir = Path(chunks_dir)
        self.embedding_model_name = embedding_model

        # Initialize components
        self.embedding_model = None
        self.index = None
        self.documents = []
        self.metadata = []

        # Vector store files
        self.index_file = self.vector_store_dir / "faiss_index.bin"
        self.documents_file = self.vector_store_dir / "documents.json"
        self.metadata_file = self.vector_store_dir / "metadata.json"
        self.config_file = self.vector_store_dir / "config.json"

        # Search parameters
        self.default_k = 5
        self.similarity_threshold = 0.3

    def initialize_embedding_model(self):
        """Initialize the optimal embedding model"""
        logger.info(f"Initializing embedding model: {self.embedding_model_name}")

        try:
            self.embedding_model = SentenceTransformer(self.embedding_model_name)
            logger.info("āœ… Embedding model loaded successfully")

            # Get embedding dimension
            test_embedding = self.embedding_model.encode(["test"])
            self.embedding_dimension = test_embedding.shape[1]
            logger.info(f"šŸ“ Embedding dimension: {self.embedding_dimension}")

        except Exception as e:
            logger.error(f"āŒ Failed to load embedding model: {e}")
            raise

    def load_medical_documents(self) -> List[Dict[str, Any]]:
        """Load processed medical documents"""
        logger.info("Loading medical documents for vector store...")

        langchain_file = self.chunks_dir / "langchain_documents_comprehensive.json"
        if not langchain_file.exists():
            raise FileNotFoundError(f"Medical documents not found: {langchain_file}")

        with open(langchain_file, 'r', encoding='utf-8') as f:
            documents = json.load(f)

        logger.info(f"šŸ“š Loaded {len(documents)} medical document chunks")
        return documents

    def create_vector_index(self, force_rebuild: bool = False) -> bool:
        """Create or load FAISS vector index"""
        # Try to reuse an existing index; fall through to a rebuild if loading
        # fails (load_existing_index() catches its own errors and returns False)
        if not force_rebuild and self.index_file.exists():
            if self.load_existing_index():
                return True
            logger.warning("Failed to load existing index")
            logger.info("Rebuilding index from scratch...")

        # Initialize embedding model if not done
        if self.embedding_model is None:
            self.initialize_embedding_model()

        # Load documents
        documents = self.load_medical_documents()
        logger.info("Creating vector embeddings for all medical chunks...")

        # Extract content and metadata
        texts = []
        metadata = []
        for doc in documents:
            content = doc['page_content']
            meta = doc['metadata']

            # Skip very short chunks
            if len(content.strip()) < 50:
                continue

            texts.append(content)
            metadata.append(meta)

        # Generate embeddings in batches
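        # Note (assumption): encoding in batches keeps peak memory bounded;
        # batch_size=32 below is a conservative default that works on CPU and
        # can usually be raised on GPU for faster indexing.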
logger.info(f"Generating embeddings for {len(texts)} chunks...") start_time = time.time() embeddings = self.embedding_model.encode( texts, batch_size=32, show_progress_bar=True, convert_to_numpy=True ) embed_time = time.time() - start_time logger.info(f"⚔ Embeddings generated in {embed_time:.2f} seconds") # Create FAISS index logger.info("Building FAISS index...") # Use IndexFlatIP for inner product (cosine similarity) # Normalize embeddings for cosine similarity faiss.normalize_L2(embeddings) # Create index index = faiss.IndexFlatIP(self.embedding_dimension) index.add(embeddings.astype('float32')) # Store components self.index = index self.documents = texts self.metadata = metadata # Save to disk self.save_index() logger.info(f"āœ… Vector store created with {index.ntotal} embeddings") return True def load_existing_index(self) -> bool: """Load existing FAISS index from disk""" logger.info("Loading existing vector store...") try: # Load FAISS index self.index = faiss.read_index(str(self.index_file)) # Load documents with open(self.documents_file, 'r', encoding='utf-8') as f: self.documents = json.load(f) # Load metadata with open(self.metadata_file, 'r', encoding='utf-8') as f: self.metadata = json.load(f) # Load config with open(self.config_file, 'r') as f: config = json.load(f) self.embedding_model_name = config['embedding_model'] self.embedding_dimension = config['embedding_dimension'] # Initialize embedding model self.initialize_embedding_model() logger.info(f"āœ… Loaded existing vector store with {self.index.ntotal} embeddings") return True except Exception as e: logger.error(f"āŒ Failed to load existing index: {e}") return False def save_index(self): """Save FAISS index and metadata to disk""" logger.info("Saving vector store to disk...") try: # Save FAISS index faiss.write_index(self.index, str(self.index_file)) # Save documents with open(self.documents_file, 'w', encoding='utf-8') as f: json.dump(self.documents, f, ensure_ascii=False, indent=2) # Save metadata with open(self.metadata_file, 'w', encoding='utf-8') as f: json.dump(self.metadata, f, ensure_ascii=False, indent=2) # Save config config = { 'embedding_model': self.embedding_model_name, 'embedding_dimension': self.embedding_dimension, 'total_chunks': len(self.documents), 'created_at': time.strftime('%Y-%m-%d %H:%M:%S') } with open(self.config_file, 'w') as f: json.dump(config, f, indent=2) logger.info(f"šŸ’¾ Vector store saved to {self.vector_store_dir}") except Exception as e: logger.error(f"āŒ Failed to save vector store: {e}") raise def search(self, query: str, k: int = None, filters: Dict[str, Any] = None, min_score: float = None) -> List[SearchResult]: """Search for relevant medical content""" if self.index is None: raise ValueError("Vector store not initialized. 
Call create_vector_index() first.") if k is None: k = self.default_k if min_score is None: min_score = self.similarity_threshold # Generate query embedding query_embedding = self.embedding_model.encode([query]) faiss.normalize_L2(query_embedding) # Search in FAISS index scores, indices = self.index.search(query_embedding.astype('float32'), k * 2) # Get more for filtering # Process results results = [] for score, idx in zip(scores[0], indices[0]): if idx == -1 or score < min_score: continue # Get document and metadata content = self.documents[idx] metadata = self.metadata[idx] # Apply filters if specified if filters and not self._matches_filters(metadata, filters): continue # Create search result result = SearchResult( content=content, score=float(score), metadata=metadata, chunk_index=idx, source_document=metadata.get('source', ''), chunk_type=metadata.get('chunk_type', 'text'), clinical_importance=metadata.get('clinical_importance', 0.5) ) results.append(result) # Stop when we have enough results if len(results) >= k: break return results def _matches_filters(self, metadata: Dict[str, Any], filters: Dict[str, Any]) -> bool: """Check if metadata matches the specified filters""" for key, value in filters.items(): if key not in metadata: return False meta_value = metadata[key] # Handle different filter types if isinstance(value, list): if meta_value not in value: return False elif isinstance(value, dict): if 'min' in value and meta_value < value['min']: return False if 'max' in value and meta_value > value['max']: return False else: if meta_value != value: return False return True def search_by_medical_context(self, query: str, content_types: List[str] = None, min_importance: float = 0.5, k: int = 5) -> List[SearchResult]: """Search with medical context filtering""" filters = {} # Filter by content types if content_types: filters['chunk_type'] = content_types # Filter by clinical importance if min_importance > 0: filters['clinical_importance'] = {'min': min_importance} return self.search(query, k=k, filters=filters) def get_statistics(self) -> Dict[str, Any]: """Get vector store statistics""" if self.index is None: return {'error': 'Vector store not initialized'} # Calculate statistics from metadata chunk_types = {} importance_distribution = {'low': 0, 'medium': 0, 'high': 0, 'critical': 0} sources = {} for meta in self.metadata: # Chunk types chunk_type = meta.get('chunk_type', 'unknown') chunk_types[chunk_type] = chunk_types.get(chunk_type, 0) + 1 # Importance distribution importance = meta.get('clinical_importance', 0) if importance >= 0.9: importance_distribution['critical'] += 1 elif importance >= 0.7: importance_distribution['high'] += 1 elif importance >= 0.5: importance_distribution['medium'] += 1 else: importance_distribution['low'] += 1 # Sources source = meta.get('source', 'unknown') sources[source] = sources.get(source, 0) + 1 return { 'total_chunks': self.index.ntotal, 'embedding_dimension': self.embedding_dimension, 'embedding_model': self.embedding_model_name, 'chunk_type_distribution': chunk_types, 'clinical_importance_distribution': importance_distribution, 'source_distribution': dict(list(sources.items())[:10]), # Top 10 sources 'vector_store_size_mb': self.index_file.stat().st_size / (1024*1024) if self.index_file.exists() else 0 } def main(): """Main function to create and test vector store""" logger.info("šŸš€ Creating Maternal Health Vector Store...") # Create vector store manager vector_store = MaternalHealthVectorStore() # Create the vector index success = 
    if not success:
        logger.error("āŒ Failed to create vector store")
        return

    # Test searches
    logger.info("\nšŸ” Testing search functionality...")

    test_queries = [
        "What is the recommended dosage of magnesium sulfate for preeclampsia?",
        "How to manage postpartum hemorrhage in emergency situations?",
        "Signs and symptoms of puerperal sepsis",
        "Normal fetal heart rate during labor"
    ]

    for query in test_queries:
        logger.info(f"\nšŸ“ Query: {query}")
        results = vector_store.search(query, k=3)

        for i, result in enumerate(results, 1):
            logger.info(f"   {i}. Score: {result.score:.3f} | Type: {result.chunk_type} | "
                        f"Importance: {result.clinical_importance:.2f}")
            logger.info(f"      Content: {result.content[:100]}...")

    # Get statistics
    stats = vector_store.get_statistics()
    logger.info("\nšŸ“Š Vector Store Statistics:")
    logger.info(f"   Total chunks: {stats['total_chunks']}")
    logger.info(f"   Embedding dimension: {stats['embedding_dimension']}")
    logger.info(f"   High importance chunks: "
                f"{stats['clinical_importance_distribution']['high'] + stats['clinical_importance_distribution']['critical']}")
    logger.info(f"   Vector store size: {stats['vector_store_size_mb']:.1f} MB")

    logger.info("\nāœ… Vector store creation and testing complete!")


if __name__ == "__main__":
    main()
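
# A minimal usage sketch of context-filtered retrieval, assuming an index has
# already been built on disk (e.g. by running main() above). The content_types
# labels passed below are hypothetical examples; the actual chunk_type values
# depend on the upstream chunking pipeline. Not invoked automatically.
def demo_context_filtered_search():
    """Demonstrate search_by_medical_context() against a prebuilt index."""
    store = MaternalHealthVectorStore()
    if not store.create_vector_index():  # loads the cached index when present
        raise RuntimeError("No vector store available; run main() first")

    results = store.search_by_medical_context(
        "magnesium sulfate loading dose for severe preeclampsia",
        content_types=["text", "table"],  # hypothetical chunk_type labels
        min_importance=0.7,               # keep only high-importance chunks
        k=3
    )
    for result in results:
        print(f"{result.score:.3f} [{result.chunk_type}] {result.content[:80]}...")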