Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
Test Suite for Maternal Health Vector Store
Validates search functionality, medical context filtering, and performance
"""
import unittest | |
import time | |
from pathlib import Path | |
from vector_store_manager import MaternalHealthVectorStore, SearchResult | |
class TestMaternalHealthVectorStore(unittest.TestCase):
    """Test suite for vector store functionality.

    One ``MaternalHealthVectorStore`` is created in ``setUpClass`` and shared
    by every test through the ``vector_store`` class attribute.
    """

    @classmethod
    def setUpClass(cls):
        """Set up test environment.

        Loads the existing index when its file is present; creates a new index
        when the file is missing or when loading reports failure.

        Must be a classmethod: unittest invokes it as ``cls.setUpClass()`` with
        no explicit argument.
        """
        cls.vector_store = MaternalHealthVectorStore()

        # Load existing vector store (should exist from previous run)
        if cls.vector_store.index_file.exists():
            print("Loading existing vector store for testing...")
            success = cls.vector_store.load_existing_index()
            if not success:
                print("Failed to load existing index, creating new one...")
                cls.vector_store.create_vector_index()
        else:
            print("Creating vector store for testing...")
            cls.vector_store.create_vector_index()

    def test_vector_store_initialization(self):
        """Test vector store loads correctly."""
        self.assertIsNotNone(self.vector_store.index)
        self.assertGreater(self.vector_store.index.ntotal, 0)
        # Every stored document must have a matching metadata entry.
        self.assertEqual(
            len(self.vector_store.documents),
            len(self.vector_store.metadata),
        )

    def test_basic_search_functionality(self):
        """Test basic search returns relevant results."""
        query = "magnesium sulfate dosage for preeclampsia"
        results = self.vector_store.search(query, k=3)

        # Should return results, but never more than requested
        self.assertGreater(len(results), 0)
        self.assertLessEqual(len(results), 3)

        # All results should be SearchResult objects that mention the drug
        for result in results:
            self.assertIsInstance(result, SearchResult)
            self.assertGreater(result.score, 0)
            self.assertIn('magnesium', result.content.lower())

    def test_medical_context_filtering(self):
        """Test filtering by medical content types."""
        query = "emergency management protocols"

        # Test filtering by emergency content
        emergency_results = self.vector_store.search_by_medical_context(
            query,
            content_types=['emergency'],
            min_importance=0.8,
            k=5,
        )

        # Should return emergency-specific results
        # NOTE(review): this loop passes vacuously when no results come back;
        # confirm whether an empty result set should be a failure here.
        for result in emergency_results:
            self.assertEqual(result.chunk_type, 'emergency')
            self.assertGreaterEqual(result.clinical_importance, 0.8)

    def test_clinical_importance_filtering(self):
        """Test filtering by clinical importance."""
        query = "dosage recommendations"

        # Test high importance filtering
        high_importance_results = self.vector_store.search_by_medical_context(
            query,
            min_importance=0.9,
            k=10,
        )

        # All results should have high clinical importance
        # NOTE(review): passes vacuously on an empty result list.
        for result in high_importance_results:
            self.assertGreaterEqual(result.clinical_importance, 0.9)

    def test_search_performance(self):
        """Test search performance is acceptable."""
        query = "normal labor management guidelines"

        # perf_counter is monotonic and high-resolution, so the measurement
        # cannot be skewed by wall-clock adjustments during the search.
        start_time = time.perf_counter()
        results = self.vector_store.search(query, k=5)
        search_time = time.perf_counter() - start_time

        # Search should be fast (under 1 second)
        self.assertLess(search_time, 1.0)
        self.assertGreater(len(results), 0)

    def test_maternal_health_queries(self):
        """Test specific maternal health queries return relevant results."""
        test_cases = [
            {
                'query': 'postpartum hemorrhage management',
                'expected_keywords': ['hemorrhage', 'postpartum', 'bleeding'],
                'min_score': 0.3,
            },
            {
                'query': 'fetal heart rate monitoring',
                'expected_keywords': ['fetal', 'heart', 'rate', 'monitoring'],
                'min_score': 0.3,
            },
            {
                'query': 'preeclampsia treatment protocols',
                'expected_keywords': ['preeclampsia', 'treatment', 'protocol'],
                'min_score': 0.3,
            },
        ]

        for case in test_cases:
            with self.subTest(query=case['query']):
                results = self.vector_store.search(case['query'], k=3)

                # Should return results
                self.assertGreater(len(results), 0)

                # Check relevance of the first (top) result
                best_result = results[0]
                self.assertGreaterEqual(best_result.score, case['min_score'])

                # At least one expected keyword must appear somewhere in the
                # combined text of the returned results.
                combined_content = ' '.join(r.content.lower() for r in results)
                keyword_found = any(
                    keyword in combined_content
                    for keyword in case['expected_keywords']
                )
                self.assertTrue(
                    keyword_found,
                    f"No keywords {case['expected_keywords']} found in results",
                )

    def test_statistics_functionality(self):
        """Test vector store statistics are accurate."""
        stats = self.vector_store.get_statistics()

        # Check required fields
        required_fields = [
            'total_chunks', 'embedding_dimension', 'embedding_model',
            'chunk_type_distribution', 'clinical_importance_distribution',
        ]
        for field in required_fields:
            self.assertIn(field, stats)

        # Check values make sense
        self.assertGreater(stats['total_chunks'], 0)
        self.assertEqual(stats['embedding_dimension'], 384)
        self.assertIn('all-MiniLM-L6-v2', stats['embedding_model'])

    def test_dosage_information_retrieval(self):
        """Test retrieval of dosage-specific information."""
        dosage_queries = [
            {
                'query': "oxytocin dosage for labor induction",
                # Include maternal and procedure
                'content_types': ['dosage', 'emergency', 'maternal', 'procedure'],
                'dosage_terms': ['oxytocin', 'administration', 'dose', 'mg', 'ml', 'unit', 'continuous'],
            },
            {
                'query': "antibiotic prophylaxis dosing",
                'content_types': ['dosage', 'emergency'],
                'dosage_terms': ['mg', 'ml', 'dose', 'dosage', 'antibiotic', 'prophylaxis'],
            },
            {
                'query': "magnesium sulfate administration",
                'content_types': ['dosage', 'emergency'],
                'dosage_terms': ['magnesium', 'sulfate', 'mg', 'dose', 'administration'],
            },
        ]

        for case in dosage_queries:
            with self.subTest(query=case['query']):
                results = self.vector_store.search_by_medical_context(
                    case['query'],
                    content_types=case['content_types'],
                    k=3,
                )

                # Should find dosage-related content
                self.assertGreater(len(results), 0)

                # Check for dosage-related terms
                combined_content = ' '.join(r.content.lower() for r in results)
                term_found = any(
                    term in combined_content for term in case['dosage_terms']
                )
                self.assertTrue(
                    term_found,
                    f"No dosage terms {case['dosage_terms']} found for query: {case['query']}",
                )

    def test_edge_cases(self):
        """Test edge cases and error handling."""
        # Empty query
        results = self.vector_store.search("", k=1)
        self.assertIsInstance(results, list)

        # Very specific query that might not match well
        results = self.vector_store.search("xyz unknown medical term", k=1)
        self.assertIsInstance(results, list)

        # Large k value
        results = self.vector_store.search("pregnancy", k=100)
        self.assertLessEqual(len(results), 100)
def _last_traceback_line(text, fallback):
    """Return the final line of a traceback string, or *fallback* if it is blank."""
    stripped = text.strip()
    # After strip() the text cannot end in a newline, so the part after the
    # last "\n" is empty only when the whole traceback was blank.
    return stripped.rpartition("\n")[2] or fallback


def _print_problem_list(heading, entries, fallback):
    """Print *heading* followed by one summary line per (test, traceback) pair."""
    print(heading)
    for test, trace_text in entries:
        print(f" - {test}: {_last_traceback_line(trace_text, fallback)}")


def run_comprehensive_tests():
    """Run all tests and provide detailed report.

    Loads every test from ``TestMaternalHealthVectorStore``, runs them with a
    verbose text runner, then prints a summary. For each failure and error the
    last line of its traceback is printed.

    Returns:
        bool: ``True`` when the run was successful (``result.wasSuccessful()``).
    """
    print("🧪 Running Comprehensive Vector Store Tests...")
    print("=" * 60)

    # Create test suite
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromTestCase(TestMaternalHealthVectorStore)

    # Run tests with detailed output
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    # Print summary
    print("\n" + "=" * 60)
    print("📊 TEST SUMMARY:")
    print(f" Tests run: {result.testsRun}")
    print(f" Failures: {len(result.failures)}")
    print(f" Errors: {len(result.errors)}")

    if result.wasSuccessful():
        print("✅ ALL TESTS PASSED! Vector store is working perfectly.")
    else:
        print("❌ Some tests failed. Check output above for details.")
        if result.failures:
            _print_problem_list("\nFailures:", result.failures, "Unknown failure")
        if result.errors:
            _print_problem_list("\nErrors:", result.errors, "Unknown error")

    return result.wasSuccessful()
if __name__ == "__main__":
    # Exit status 0 on success, 1 otherwise. SystemExit is raised directly
    # because the bare exit() helper is injected by the site module and is
    # not available under `python -S` or in frozen/embedded interpreters.
    success = run_comprehensive_tests()
    raise SystemExit(0 if success else 1)