""" | |
Simple Vector Store for Medical RAG v2.0 | |
Research-backed approach: Document-based retrieval with simple metadata | |
""" | |
import os | |
import json | |
import logging | |
import time | |
from typing import List, Dict, Any, Optional, Tuple | |
from pathlib import Path | |
import numpy as np | |
from dataclasses import dataclass | |
# Vector store and embeddings | |
import faiss | |
from sentence_transformers import SentenceTransformer | |
from langchain_core.documents import Document | |
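# Quickstart (a minimal sketch of the intended flow; assumes `chunks` is a
# list of LangChain Document objects, e.g. from SimpleDocumentChunker as in
# main() below):
#
#     store = SimpleVectorStore()
#     embeddings, n = store.create_embeddings(chunks)
#     store.build_index(embeddings)
#     store.save_vector_store()
#     results = store.search("postpartum hemorrhage management", k=3)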
@dataclass
class SearchResult:
    """Simple search result structure"""
    content: str
    score: float
    metadata: Dict[str, Any]
    document_name: str
    content_type: str
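# Example (hypothetical values) of a SearchResult as returned by
# SimpleVectorStore.search():
#
#     SearchResult(content="Administer magnesium sulfate ...", score=0.82,
#                  metadata={"document_name": "guideline.pdf"},
#                  document_name="guideline.pdf", content_type="dosage")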
class SimpleVectorStore:
    """
    Simple vector store for document-based retrieval
    - Cosine similarity over normalized sentence-transformer embeddings
    - Simplified metadata structure
    - Exact (flat) FAISS indexing
    """
    def __init__(self,
                 embedding_model: str = "all-MiniLM-L6-v2",
                 index_type: str = "IndexFlatIP",  # Inner product == cosine on normalized vectors
                 vector_store_dir: str = "simple_vector_store"):
        """
        Initialize the simple vector store

        Args:
            embedding_model: Sentence transformer model name
            index_type: FAISS index type ("IndexFlatIP" or "IndexFlatL2")
            vector_store_dir: Directory to store vector index and metadata
        """
        self.embedding_model_name = embedding_model
        self.index_type = index_type
        self.vector_store_dir = Path(vector_store_dir)
        self.vector_store_dir.mkdir(parents=True, exist_ok=True)

        # Initialize components
        self.embedding_model = None
        self.index = None
        self.documents = []
        self.metadata = []

        self.setup_logging()
        self._initialize_embedding_model()
    def setup_logging(self):
        """Setup logging for the vector store"""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)
    def _initialize_embedding_model(self):
        """Initialize the sentence transformer model"""
        try:
            self.logger.info(f"Loading embedding model: {self.embedding_model_name}")
            self.embedding_model = SentenceTransformer(self.embedding_model_name)
            self.logger.info("Embedding model loaded successfully")
        except Exception as e:
            self.logger.error(f"Error loading embedding model: {e}")
            raise
    def create_embeddings(self, chunks: List[Document]) -> Tuple[np.ndarray, int]:
        """Create embeddings for document chunks"""
        if not chunks:
            raise ValueError("No chunks provided for embedding")

        start_time = time.time()

        # Extract text content
        texts = [chunk.page_content for chunk in chunks]
        self.logger.info(f"Creating embeddings for {len(texts)} chunks...")

        # Generate embeddings
        embeddings = self.embedding_model.encode(
            texts,
            show_progress_bar=True,
            batch_size=32,
            normalize_embeddings=True  # Required for cosine similarity via inner product
        )

        # Store documents and metadata
        self.documents = chunks
        self.metadata = [chunk.metadata for chunk in chunks]

        embedding_time = time.time() - start_time
        self.logger.info(f"Created {len(embeddings)} embeddings in {embedding_time:.2f} seconds")
        return embeddings, len(embeddings)
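    # Example (sketch; assumes `chunks` is a non-empty list of Documents):
    #
    #     embeddings, count = store.create_embeddings(chunks)
    #     embeddings.shape  # (count, 384) for all-MiniLM-L6-v2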
    def build_index(self, embeddings: np.ndarray):
        """Build FAISS index from embeddings"""
        dimension = embeddings.shape[1]

        # Create FAISS index
        if self.index_type == "IndexFlatIP":
            # Inner product index (cosine similarity on normalized embeddings)
            self.index = faiss.IndexFlatIP(dimension)
        elif self.index_type == "IndexFlatL2":
            # L2 distance index
            self.index = faiss.IndexFlatL2(dimension)
        else:
            raise ValueError(f"Unsupported index type: {self.index_type}")

        # Add embeddings to index (FAISS requires float32)
        self.index.add(embeddings.astype('float32'))
        self.logger.info(f"Built FAISS index with {self.index.ntotal} vectors")
    def save_vector_store(self):
        """Save vector store to disk"""
        if self.index is None:
            raise ValueError("No index to save. Build the index first.")
        try:
            # Save FAISS index
            index_path = self.vector_store_dir / "faiss_index.bin"
            faiss.write_index(self.index, str(index_path))

            # Save documents
            docs_path = self.vector_store_dir / "documents.json"
            docs_data = []
            for doc in self.documents:
                docs_data.append({
                    'page_content': doc.page_content,
                    'metadata': doc.metadata
                })
            with open(docs_path, 'w', encoding='utf-8') as f:
                json.dump(docs_data, f, indent=2, ensure_ascii=False)

            # Save configuration
            config_path = self.vector_store_dir / "config.json"
            config = {
                'embedding_model': self.embedding_model_name,
                'index_type': self.index_type,
                'total_documents': len(self.documents),
                'dimension': self.index.d,
                'created_at': time.strftime('%Y-%m-%d %H:%M:%S')
            }
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(config, f, indent=2)

            self.logger.info(f"Vector store saved to {self.vector_store_dir}")
        except Exception as e:
            self.logger.error(f"Error saving vector store: {e}")
            raise
    def load_vector_store(self) -> bool:
        """Load vector store from disk"""
        try:
            index_path = self.vector_store_dir / "faiss_index.bin"
            docs_path = self.vector_store_dir / "documents.json"
            config_path = self.vector_store_dir / "config.json"

            if not all(p.exists() for p in [index_path, docs_path, config_path]):
                return False

            # Load FAISS index
            self.index = faiss.read_index(str(index_path))

            # Load documents
            with open(docs_path, 'r', encoding='utf-8') as f:
                docs_data = json.load(f)
            self.documents = []
            self.metadata = []
            for doc_data in docs_data:
                doc = Document(
                    page_content=doc_data['page_content'],
                    metadata=doc_data['metadata']
                )
                self.documents.append(doc)
                self.metadata.append(doc_data['metadata'])

            # Load configuration and check it matches the current model;
            # a mismatch means query embeddings won't align with the index
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            if config.get('embedding_model') != self.embedding_model_name:
                self.logger.warning(
                    f"Saved store was built with '{config.get('embedding_model')}' "
                    f"but the current model is '{self.embedding_model_name}'"
                )

            self.logger.info(f"Loaded vector store with {len(self.documents)} documents")
            return True
        except Exception as e:
            self.logger.error(f"Error loading vector store: {e}")
            return False
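    # Round-trip example (sketch; all files live under vector_store_dir):
    #
    #     store = SimpleVectorStore(vector_store_dir="simple_vector_store")
    #     if not store.load_vector_store():
    #         ...  # build from chunks, then store.save_vector_store()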
    def search(self,
               query: str,
               k: int = 5,
               content_type_filter: Optional[str] = None) -> List[SearchResult]:
        """
        Search for similar documents

        Args:
            query: Search query
            k: Number of results to return
            content_type_filter: Filter by content type (optional)

        Returns:
            List of SearchResult objects
        """
        if not self.index or not self.documents:
            raise ValueError("Vector store not initialized. Load or create index first.")

        # Create query embedding
        query_embedding = self.embedding_model.encode(
            [query],
            normalize_embeddings=True
        )

        # Search the FAISS index; over-fetch so post-hoc filtering
        # can still return k results
        search_k = min(k * 3, len(self.documents))
        scores, indices = self.index.search(query_embedding.astype('float32'), search_k)

        # Process results
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx == -1:  # FAISS pads with -1 when fewer than search_k hits exist
                continue

            doc = self.documents[idx]
            metadata = self.metadata[idx]

            # Apply content type filter if specified
            if content_type_filter:
                doc_content_type = metadata.get('content_type', '')
                if content_type_filter.lower() not in doc_content_type.lower():
                    continue

            results.append(SearchResult(
                content=doc.page_content,
                score=float(score),
                metadata=metadata,
                document_name=metadata.get('document_name', 'Unknown'),
                content_type=metadata.get('content_type', 'general')
            ))

            # Stop when we have enough results
            if len(results) >= k:
                break
        return results
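    # Example (the content_type values are hypothetical; they depend on the
    # metadata the chunker attaches):
    #
    #     hits = store.search("magnesium sulfate dosage", k=5,
    #                         content_type_filter="dosage")
    #     for hit in hits:
    #         print(hit.score, hit.document_name)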
    def get_stats(self) -> Dict[str, Any]:
        """Get vector store statistics"""
        if not self.documents:
            return {"status": "empty"}

        # Document statistics
        doc_counts = {}
        content_type_counts = {}
        total_chars = 0
        for doc in self.documents:
            # Document distribution
            doc_name = doc.metadata.get('document_name', 'Unknown')
            doc_counts[doc_name] = doc_counts.get(doc_name, 0) + 1

            # Content type distribution
            content_type = doc.metadata.get('content_type', 'general')
            content_type_counts[content_type] = content_type_counts.get(content_type, 0) + 1

            # Character count
            total_chars += len(doc.page_content)

        # Vector store size estimation
        if self.index:
            # Estimate size: vectors (4 bytes per float32) plus raw text
            vector_size_mb = (self.index.ntotal * self.index.d * 4) / (1024 * 1024)
            metadata_size_mb = total_chars / (1024 * 1024)  # Rough estimate
            total_size_mb = vector_size_mb + metadata_size_mb
        else:
            total_size_mb = 0

        return {
            "status": "ready",
            "total_chunks": len(self.documents),
            "embedding_model": self.embedding_model_name,
            "index_type": self.index_type,
            "vector_dimension": self.index.d if self.index else 0,
            "vector_store_size_mb": round(total_size_mb, 2),
            "avg_chunk_size": round(total_chars / len(self.documents), 1),
            "document_distribution": dict(sorted(doc_counts.items(), key=lambda x: x[1], reverse=True)[:10]),
            "content_type_distribution": content_type_counts
        }
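    # Sizing sketch: 10,000 chunks of 384-dim float32 vectors come to
    # 10_000 * 384 * 4 bytes ≈ 14.6 MB of index, before the JSON text.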
def main():
    """Main function to test the simple vector store"""
    print("🚀 Testing Simple Vector Store v2.0")
    print("=" * 60)
    try:
        # Initialize vector store
        vector_store = SimpleVectorStore(
            embedding_model="all-MiniLM-L6-v2",
            index_type="IndexFlatIP"
        )

        # Check if we can load an existing vector store
        if vector_store.load_vector_store():
            print("✅ Loaded existing vector store")
        else:
            print("🔄 Creating new vector store from chunks...")

            # Load chunks from simple chunker
            from simple_document_chunker import SimpleDocumentChunker
            chunker = SimpleDocumentChunker()
            documents = chunker.load_processed_documents()
            chunks = chunker.create_simple_chunks(documents)
            print(f"✅ Loaded {len(chunks)} chunks for embedding")

            # Create embeddings
            embeddings, count = vector_store.create_embeddings(chunks)

            # Build index
            vector_store.build_index(embeddings)

            # Save vector store
            vector_store.save_vector_store()
            print("✅ Vector store created and saved")

        # Get statistics
        stats = vector_store.get_stats()
        print("\n📊 VECTOR STORE STATISTICS:")
        print(f"   Status: {stats['status'].upper()}")
        print(f"   Total chunks: {stats['total_chunks']:,}")
        print(f"   Embedding model: {stats['embedding_model']}")
        print(f"   Vector dimension: {stats['vector_dimension']}")
        print(f"   Store size: {stats['vector_store_size_mb']} MB")
        print(f"   Avg chunk size: {stats['avg_chunk_size']:.0f} chars")

        print("\n📋 Content Type Distribution:")
        for content_type, count in stats['content_type_distribution'].items():
            percentage = (count / stats['total_chunks']) * 100
            print(f"   {content_type}: {count:,} chunks ({percentage:.1f}%)")

        # Test search functionality
        print("\n🔍 TESTING SEARCH FUNCTIONALITY:")
        test_queries = [
            "magnesium sulfate dosage preeclampsia",
            "postpartum hemorrhage management",
            "fetal heart rate monitoring",
            "emergency cesarean delivery"
        ]
        for query in test_queries:
            print(f"\n🔎 Query: '{query}'")
            results = vector_store.search(query, k=3)
            for i, result in enumerate(results, 1):
                print(f"   Result {i}: Score={result.score:.3f}, Doc={result.document_name}")
                print(f"      Type={result.content_type}")
                print(f"      Preview: {result.content[:100]}...")

        print("\n🎉 Simple Vector Store Testing Complete!")
        print(f"✅ Successfully created vector store with {stats['total_chunks']:,} embeddings")
        print("✅ Search functionality verified on the sample queries")
        return vector_store
    except Exception as e:
        print(f"❌ Error in simple vector store: {e}")
        import traceback
        traceback.print_exc()
        return None
if __name__ == "__main__":
    main()