# Provenance (from scraped repository page): vedaMD / src/test_vector_store.py
# Author: sniro23 — commit 19aaa42 "Initial commit without binary files"
#!/usr/bin/env python3
"""
Test Suite for Maternal Health Vector Store
Validates search functionality, medical context filtering, and performance
"""
import unittest
import time
from pathlib import Path
from vector_store_manager import MaternalHealthVectorStore, SearchResult
class TestMaternalHealthVectorStore(unittest.TestCase):
    """Test suite for vector store functionality.

    One ``MaternalHealthVectorStore`` instance is shared by every test in the
    class (see ``setUpClass``). Most tests are integration checks against the
    real index, so their outcome depends on the indexed corpus.
    """

    @classmethod
    def setUpClass(cls):
        """Set up test environment.

        Loads the on-disk index when ``index_file`` exists and
        ``load_existing_index()`` reports success; otherwise builds a new
        index with ``create_vector_index()``.
        """
        cls.vector_store = MaternalHealthVectorStore()
        # Load existing vector store (should exist from previous run)
        if cls.vector_store.index_file.exists():
            print("Loading existing vector store for testing...")
            success = cls.vector_store.load_existing_index()
            if not success:
                print("Failed to load existing index, creating new one...")
                cls.vector_store.create_vector_index()
        else:
            print("Creating vector store for testing...")
            cls.vector_store.create_vector_index()

    def test_vector_store_initialization(self):
        """Test vector store loads correctly."""
        self.assertIsNotNone(self.vector_store.index)
        # ``ntotal`` is the number of vectors stored in the index.
        self.assertGreater(self.vector_store.index.ntotal, 0)
        # Every stored document must have exactly one metadata entry.
        self.assertEqual(len(self.vector_store.documents), len(self.vector_store.metadata))

    def test_basic_search_functionality(self):
        """Test basic search returns relevant results."""
        query = "magnesium sulfate dosage for preeclampsia"
        results = self.vector_store.search(query, k=3)
        # Should return results, but never more than requested
        self.assertGreater(len(results), 0)
        self.assertLessEqual(len(results), 3)
        # All results should be SearchResult objects with a positive score
        # whose text mentions the queried drug.
        for result in results:
            self.assertIsInstance(result, SearchResult)
            self.assertGreater(result.score, 0)
            self.assertIn('magnesium', result.content.lower())

    def test_medical_context_filtering(self):
        """Test filtering by medical content types."""
        query = "emergency management protocols"
        # Test filtering by emergency content
        emergency_results = self.vector_store.search_by_medical_context(
            query,
            content_types=['emergency'],
            min_importance=0.8,
            k=5
        )
        # Should return emergency-specific results only.
        # NOTE(review): an empty result list passes this test vacuously;
        # assert len(emergency_results) > 0 if the corpus guarantees matches.
        for result in emergency_results:
            self.assertEqual(result.chunk_type, 'emergency')
            self.assertGreaterEqual(result.clinical_importance, 0.8)

    def test_clinical_importance_filtering(self):
        """Test filtering by clinical importance."""
        query = "dosage recommendations"
        # Test high importance filtering
        high_importance_results = self.vector_store.search_by_medical_context(
            query,
            min_importance=0.9,
            k=10
        )
        # All results should have high clinical importance.
        # NOTE(review): like the test above, this passes when nothing is returned.
        for result in high_importance_results:
            self.assertGreaterEqual(result.clinical_importance, 0.9)

    def test_search_performance(self):
        """Test search performance is acceptable (a single query under 1 second)."""
        query = "normal labor management guidelines"
        # perf_counter() is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump when the system clock is adjusted.
        start_time = time.perf_counter()
        results = self.vector_store.search(query, k=5)
        search_time = time.perf_counter() - start_time
        # Search should be fast (under 1 second)
        self.assertLess(search_time, 1.0)
        self.assertGreater(len(results), 0)

    def test_maternal_health_queries(self):
        """Test specific maternal health queries return relevant results."""
        test_cases = [
            {
                'query': 'postpartum hemorrhage management',
                'expected_keywords': ['hemorrhage', 'postpartum', 'bleeding'],
                'min_score': 0.3
            },
            {
                'query': 'fetal heart rate monitoring',
                'expected_keywords': ['fetal', 'heart', 'rate', 'monitoring'],
                'min_score': 0.3
            },
            {
                'query': 'preeclampsia treatment protocols',
                'expected_keywords': ['preeclampsia', 'treatment', 'protocol'],
                'min_score': 0.3
            }
        ]
        for case in test_cases:
            # subTest lets the remaining queries run even if one fails.
            with self.subTest(query=case['query']):
                results = self.vector_store.search(case['query'], k=3)
                # Should return results
                self.assertGreater(len(results), 0)
                # Check relevance of the first hit.
                # NOTE(review): assumes search() returns results best-first — confirm.
                best_result = results[0]
                self.assertGreaterEqual(best_result.score, case['min_score'])
                # At least one expected keyword must appear somewhere in the hits.
                combined_content = ' '.join([r.content.lower() for r in results])
                keyword_found = any(
                    keyword in combined_content
                    for keyword in case['expected_keywords']
                )
                self.assertTrue(keyword_found,
                                f"No keywords {case['expected_keywords']} found in results")

    def test_statistics_functionality(self):
        """Test vector store statistics are accurate."""
        stats = self.vector_store.get_statistics()
        # Check required fields
        required_fields = [
            'total_chunks', 'embedding_dimension', 'embedding_model',
            'chunk_type_distribution', 'clinical_importance_distribution'
        ]
        for field in required_fields:
            self.assertIn(field, stats)
        # Check values make sense; the expected dimension and model name are
        # pinned to the embedding model this project is configured with.
        self.assertGreater(stats['total_chunks'], 0)
        self.assertEqual(stats['embedding_dimension'], 384)
        self.assertIn('all-MiniLM-L6-v2', stats['embedding_model'])

    def test_dosage_information_retrieval(self):
        """Test retrieval of dosage-specific information."""
        dosage_queries = [
            {
                'query': "oxytocin dosage for labor induction",
                'content_types': ['dosage', 'emergency', 'maternal', 'procedure'],  # Include maternal and procedure
                'dosage_terms': ['oxytocin', 'administration', 'dose', 'mg', 'ml', 'unit', 'continuous']
            },
            {
                'query': "antibiotic prophylaxis dosing",
                'content_types': ['dosage', 'emergency'],
                'dosage_terms': ['mg', 'ml', 'dose', 'dosage', 'antibiotic', 'prophylaxis']
            },
            {
                'query': "magnesium sulfate administration",
                'content_types': ['dosage', 'emergency'],
                'dosage_terms': ['magnesium', 'sulfate', 'mg', 'dose', 'administration']
            }
        ]
        for case in dosage_queries:
            with self.subTest(query=case['query']):
                results = self.vector_store.search_by_medical_context(
                    case['query'],
                    content_types=case['content_types'],
                    k=3
                )
                # Should find dosage-related content
                self.assertGreater(len(results), 0)
                # Check for dosage-related terms anywhere in the hits
                combined_content = ' '.join([r.content.lower() for r in results])
                term_found = any(term in combined_content for term in case['dosage_terms'])
                self.assertTrue(term_found,
                                f"No dosage terms {case['dosage_terms']} found for query: {case['query']}")

    def test_edge_cases(self):
        """Test edge cases and error handling."""
        # Empty query
        results = self.vector_store.search("", k=1)
        self.assertIsInstance(results, list)
        # Very specific query that might not match well
        results = self.vector_store.search("xyz unknown medical term", k=1)
        self.assertIsInstance(results, list)
        # Large k value: must still return a list, capped at k
        results = self.vector_store.search("pregnancy", k=100)
        self.assertIsInstance(results, list)
        self.assertLessEqual(len(results), 100)
def _print_problems(title, problems, fallback):
    """Print a heading plus one summary line per ``(test, traceback)`` pair.

    Args:
        title: Heading printed before the list (e.g. ``"Failures"``).
        problems: ``TestResult.failures`` or ``TestResult.errors``, i.e. a
            list of ``(TestCase, formatted_traceback_str)`` tuples.
        fallback: Text shown when a traceback string is empty.

    Prints nothing when ``problems`` is empty.
    """
    if not problems:
        return
    print(f"\n{title}:")
    for test, tb_text in problems:
        # Keep only the last traceback line (usually "ExceptionType: message").
        # rsplit on the stripped text yields "" only when the traceback is
        # empty, so the fallback is actually reachable (split('\n') on an
        # empty string returns [''], never [], which defeated the old check).
        error_line = tb_text.strip().rsplit('\n', 1)[-1] or fallback
        print(f" - {test}: {error_line}")


def run_comprehensive_tests():
    """Run all tests and provide detailed report.

    Runs ``TestMaternalHealthVectorStore`` with a verbose text runner, then
    prints a summary. On failure it also prints the last traceback line of
    every failed or errored test.

    Returns:
        bool: ``True`` when ``TestResult.wasSuccessful()`` is true.
    """
    print("🧪 Running Comprehensive Vector Store Tests...")
    print("=" * 60)
    # Create test suite
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromTestCase(TestMaternalHealthVectorStore)
    # Run tests with detailed output
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)
    # Print summary
    print("\n" + "=" * 60)
    print("📊 TEST SUMMARY:")
    print(f" Tests run: {result.testsRun}")
    print(f" Failures: {len(result.failures)}")
    print(f" Errors: {len(result.errors)}")
    if result.wasSuccessful():
        print("✅ ALL TESTS PASSED! Vector store is working perfectly.")
    else:
        print("❌ Some tests failed. Check output above for details.")
        _print_problems("Failures", result.failures, "Unknown failure")
        _print_problems("Errors", result.errors, "Unknown error")
    return result.wasSuccessful()
if __name__ == "__main__":
    # Exit with status 0 on success and 1 otherwise, so shells/CI see failures.
    # Raise SystemExit (equivalent to sys.exit) instead of using the
    # site-injected ``exit`` helper. That helper is intended for the
    # interactive interpreter and is not defined under ``python -S``.
    success = run_comprehensive_tests()
    raise SystemExit(0 if success else 1)