#!/usr/bin/env python3
"""
Vector Store Manager for Maternal Health RAG Chatbot

Uses FAISS with the all-MiniLM-L6-v2 sentence-transformer embedding model.
"""
import json
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

import faiss
from sentence_transformers import SentenceTransformer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class SearchResult:
    """Container for a single retrieval result"""
    content: str
    score: float
    metadata: Dict[str, Any]
    chunk_index: int
    source_document: str
    chunk_type: str
    clinical_importance: float
class MaternalHealthVectorStore:
    """Vector store for maternal health guidelines with clinical context filtering"""

    def __init__(self,
                 vector_store_dir: str = "vector_store",
                 embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
                 chunks_dir: str = "comprehensive_chunks"):
        self.vector_store_dir = Path(vector_store_dir)
        self.vector_store_dir.mkdir(exist_ok=True)
        self.chunks_dir = Path(chunks_dir)
        self.embedding_model_name = embedding_model

        # Initialize components
        self.embedding_model = None
        self.index = None
        self.documents = []
        self.metadata = []

        # Vector store files
        self.index_file = self.vector_store_dir / "faiss_index.bin"
        self.documents_file = self.vector_store_dir / "documents.json"
        self.metadata_file = self.vector_store_dir / "metadata.json"
        self.config_file = self.vector_store_dir / "config.json"

        # Search parameters
        self.default_k = 5
        self.similarity_threshold = 0.3
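        # With L2-normalized embeddings, FAISS inner-product scores equal
        # cosine similarity and lie in [-1, 1]; the 0.3 default is a loose
        # relevance floor and may need tuning per corpus.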
    def initialize_embedding_model(self):
        """Initialize the configured sentence-transformer embedding model"""
        logger.info(f"Initializing embedding model: {self.embedding_model_name}")

        try:
            self.embedding_model = SentenceTransformer(self.embedding_model_name)
            logger.info("✅ Embedding model loaded successfully")

            # Probe the model once to record the embedding dimension
            test_embedding = self.embedding_model.encode(["test"])
            self.embedding_dimension = test_embedding.shape[1]
            logger.info(f"📏 Embedding dimension: {self.embedding_dimension}")
        except Exception as e:
            logger.error(f"❌ Failed to load embedding model: {e}")
            raise
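    # Note: all-MiniLM-L6-v2 produces 384-dimensional embeddings and silently
    # truncates inputs beyond its maximum sequence length (256 word-piece
    # tokens), so very long chunks are represented by their leading text only.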
    def load_medical_documents(self) -> List[Dict[str, Any]]:
        """Load processed medical documents"""
        logger.info("Loading medical documents for vector store...")

        langchain_file = self.chunks_dir / "langchain_documents_comprehensive.json"
        if not langchain_file.exists():
            raise FileNotFoundError(f"Medical documents not found: {langchain_file}")

        with open(langchain_file, 'r', encoding='utf-8') as f:
            documents = json.load(f)

        logger.info(f"📚 Loaded {len(documents)} medical document chunks")
        return documents
    def create_vector_index(self, force_rebuild: bool = False) -> bool:
        """Create or load the FAISS vector index"""
        # Reuse an existing index when possible
        if not force_rebuild and self.index_file.exists():
            try:
                return self.load_existing_index()
            except Exception as e:
                logger.warning(f"Failed to load existing index: {e}")
                logger.info("Rebuilding index from scratch...")

        # Initialize embedding model if not done yet
        if self.embedding_model is None:
            self.initialize_embedding_model()

        # Load documents
        documents = self.load_medical_documents()
        logger.info("Creating vector embeddings for all medical chunks...")

        # Extract content and metadata, skipping very short chunks
        texts = []
        metadata = []
        for doc in documents:
            content = doc['page_content']
            meta = doc['metadata']
            if len(content.strip()) < 50:
                continue
            texts.append(content)
            metadata.append(meta)

        # Generate embeddings in batches
        logger.info(f"Generating embeddings for {len(texts)} chunks...")
        start_time = time.time()
        embeddings = self.embedding_model.encode(
            texts,
            batch_size=32,
            show_progress_bar=True,
            convert_to_numpy=True
        )
        embed_time = time.time() - start_time
        logger.info(f"⚡ Embeddings generated in {embed_time:.2f} seconds")

        # Build the FAISS index. IndexFlatIP computes inner products, which
        # equal cosine similarities once the embeddings are L2-normalized.
        logger.info("Building FAISS index...")
        embeddings = embeddings.astype('float32')
        faiss.normalize_L2(embeddings)
        index = faiss.IndexFlatIP(self.embedding_dimension)
        index.add(embeddings)

        # Store components
        self.index = index
        self.documents = texts
        self.metadata = metadata

        # Save to disk
        self.save_index()

        logger.info(f"✅ Vector store created with {index.ntotal} embeddings")
        return True
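    # IndexFlatIP is exact brute-force search, which is reasonable at this
    # corpus size. For much larger corpora an approximate FAISS index could be
    # swapped in instead (a sketch, not what the code above uses):
    #   quantizer = faiss.IndexFlatIP(dim)
    #   index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)
    #   index.train(embeddings)
    #   index.add(embeddings)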
    def load_existing_index(self) -> bool:
        """Load existing FAISS index from disk"""
        logger.info("Loading existing vector store...")

        try:
            # Load FAISS index
            self.index = faiss.read_index(str(self.index_file))

            # Load documents
            with open(self.documents_file, 'r', encoding='utf-8') as f:
                self.documents = json.load(f)

            # Load metadata
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                self.metadata = json.load(f)

            # Load config
            with open(self.config_file, 'r') as f:
                config = json.load(f)
                self.embedding_model_name = config['embedding_model']
                self.embedding_dimension = config['embedding_dimension']

            # Initialize embedding model
            self.initialize_embedding_model()

            logger.info(f"✅ Loaded existing vector store with {self.index.ntotal} embeddings")
            return True
        except Exception as e:
            logger.error(f"❌ Failed to load existing index: {e}")
            return False
    def save_index(self):
        """Save FAISS index and metadata to disk"""
        logger.info("Saving vector store to disk...")

        try:
            # Save FAISS index
            faiss.write_index(self.index, str(self.index_file))

            # Save documents
            with open(self.documents_file, 'w', encoding='utf-8') as f:
                json.dump(self.documents, f, ensure_ascii=False, indent=2)

            # Save metadata
            with open(self.metadata_file, 'w', encoding='utf-8') as f:
                json.dump(self.metadata, f, ensure_ascii=False, indent=2)

            # Save config
            config = {
                'embedding_model': self.embedding_model_name,
                'embedding_dimension': self.embedding_dimension,
                'total_chunks': len(self.documents),
                'created_at': time.strftime('%Y-%m-%d %H:%M:%S')
            }
            with open(self.config_file, 'w') as f:
                json.dump(config, f, indent=2)

            logger.info(f"💾 Vector store saved to {self.vector_store_dir}")
        except Exception as e:
            logger.error(f"❌ Failed to save vector store: {e}")
            raise
    def search(self,
               query: str,
               k: Optional[int] = None,
               filters: Optional[Dict[str, Any]] = None,
               min_score: Optional[float] = None) -> List[SearchResult]:
        """Search for relevant medical content"""
        if self.index is None:
            raise ValueError("Vector store not initialized. Call create_vector_index() first.")

        if k is None:
            k = self.default_k
        if min_score is None:
            min_score = self.similarity_threshold

        # Generate and normalize the query embedding (cosine similarity)
        query_embedding = self.embedding_model.encode([query]).astype('float32')
        faiss.normalize_L2(query_embedding)

        # Over-fetch so that post-hoc filtering can still return k results
        scores, indices = self.index.search(query_embedding, k * 2)

        # Process results
        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads with -1 when fewer than k*2 vectors are available
            if idx == -1 or score < min_score:
                continue

            # Get document and metadata
            content = self.documents[idx]
            metadata = self.metadata[idx]

            # Apply filters if specified
            if filters and not self._matches_filters(metadata, filters):
                continue

            # Create search result
            result = SearchResult(
                content=content,
                score=float(score),
                metadata=metadata,
                chunk_index=int(idx),
                source_document=metadata.get('source', ''),
                chunk_type=metadata.get('chunk_type', 'text'),
                clinical_importance=metadata.get('clinical_importance', 0.5)
            )
            results.append(result)

            # Stop once we have enough results
            if len(results) >= k:
                break

        return results
    def _matches_filters(self, metadata: Dict[str, Any], filters: Dict[str, Any]) -> bool:
        """Check if metadata matches the specified filters"""
        for key, value in filters.items():
            if key not in metadata:
                return False

            meta_value = metadata[key]

            # A list filter matches any of its values; a dict filter is a
            # numeric range; anything else requires exact equality
            if isinstance(value, list):
                if meta_value not in value:
                    return False
            elif isinstance(value, dict):
                if 'min' in value and meta_value < value['min']:
                    return False
                if 'max' in value and meta_value > value['max']:
                    return False
            else:
                if meta_value != value:
                    return False

        return True
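    # Filter shapes accepted by search(), for reference (illustrative values;
    # the actual keys depend on what the chunking step emitted):
    #   {'chunk_type': ['table', 'text'],        # membership in a list
    #    'clinical_importance': {'min': 0.7},    # numeric range
    #    'source': 'guideline.pdf'}              # exact match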
    def search_by_medical_context(self,
                                  query: str,
                                  content_types: Optional[List[str]] = None,
                                  min_importance: float = 0.5,
                                  k: int = 5) -> List[SearchResult]:
        """Search with medical context filtering"""
        filters = {}

        # Filter by content types
        if content_types:
            filters['chunk_type'] = content_types

        # Filter by clinical importance
        if min_importance > 0:
            filters['clinical_importance'] = {'min': min_importance}

        return self.search(query, k=k, filters=filters)
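    # For example, to bias retrieval toward high-stakes guidance
    # (hypothetical query and store instance):
    #   store.search_by_medical_context("eclampsia seizure management",
    #                                   content_types=["text"],
    #                                   min_importance=0.8)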
    def get_statistics(self) -> Dict[str, Any]:
        """Get vector store statistics"""
        if self.index is None:
            return {'error': 'Vector store not initialized'}

        # Calculate statistics from metadata
        chunk_types = {}
        importance_distribution = {'low': 0, 'medium': 0, 'high': 0, 'critical': 0}
        sources = {}

        for meta in self.metadata:
            # Chunk types
            chunk_type = meta.get('chunk_type', 'unknown')
            chunk_types[chunk_type] = chunk_types.get(chunk_type, 0) + 1

            # Importance distribution
            importance = meta.get('clinical_importance', 0)
            if importance >= 0.9:
                importance_distribution['critical'] += 1
            elif importance >= 0.7:
                importance_distribution['high'] += 1
            elif importance >= 0.5:
                importance_distribution['medium'] += 1
            else:
                importance_distribution['low'] += 1

            # Sources
            source = meta.get('source', 'unknown')
            sources[source] = sources.get(source, 0) + 1

        # Rank sources by chunk count so the slice really is the top 10
        top_sources = dict(sorted(sources.items(), key=lambda kv: kv[1], reverse=True)[:10])

        return {
            'total_chunks': self.index.ntotal,
            'embedding_dimension': self.embedding_dimension,
            'embedding_model': self.embedding_model_name,
            'chunk_type_distribution': chunk_types,
            'clinical_importance_distribution': importance_distribution,
            'source_distribution': top_sources,
            'vector_store_size_mb': self.index_file.stat().st_size / (1024 * 1024) if self.index_file.exists() else 0
        }
def main():
    """Main function to create and test the vector store"""
    logger.info("🚀 Creating Maternal Health Vector Store...")

    # Create vector store manager
    vector_store = MaternalHealthVectorStore()

    # Create the vector index
    success = vector_store.create_vector_index()
    if not success:
        logger.error("❌ Failed to create vector store")
        return

    # Test searches
    logger.info("\n🔍 Testing search functionality...")
    test_queries = [
        "What is the recommended dosage of magnesium sulfate for preeclampsia?",
        "How to manage postpartum hemorrhage in emergency situations?",
        "Signs and symptoms of puerperal sepsis",
        "Normal fetal heart rate during labor"
    ]

    for query in test_queries:
        logger.info(f"\n📝 Query: {query}")
        results = vector_store.search(query, k=3)

        for i, result in enumerate(results, 1):
            logger.info(f"   {i}. Score: {result.score:.3f} | Type: {result.chunk_type} | "
                        f"Importance: {result.clinical_importance:.2f}")
            logger.info(f"      Content: {result.content[:100]}...")

    # Get statistics
    stats = vector_store.get_statistics()
    logger.info("\n📊 Vector Store Statistics:")
    logger.info(f"   Total chunks: {stats['total_chunks']}")
    logger.info(f"   Embedding dimension: {stats['embedding_dimension']}")
    logger.info(f"   High importance chunks: "
                f"{stats['clinical_importance_distribution']['high'] + stats['clinical_importance_distribution']['critical']}")
    logger.info(f"   Vector store size: {stats['vector_store_size_mb']:.1f} MB")

    logger.info("\n✅ Vector store creation and testing complete!")


if __name__ == "__main__":
    main()