Spaces:
Sleeping
Sleeping
File size: 11,852 Bytes
19aaa42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 |
#!/usr/bin/env python3
"""
Test Suite for Maternal Health RAG System
Validates complete RAG pipeline including vector retrieval and response generation
"""
import unittest
import time
from typing import List, Dict, Any
from pathlib import Path
from maternal_health_rag import MaternalHealthRAG, QueryResponse
class TestMaternalHealthRAG(unittest.TestCase):
    """Test suite for RAG system functionality.

    A single ``MaternalHealthRAG`` instance (with a mock LLM backend) is
    built once in ``setUpClass`` and shared by every test, since loading
    the vector store is the expensive part of initialization.
    """

    @classmethod
    def setUpClass(cls):
        """Set up the shared RAG system under test (mock LLM, no API calls)."""
        print("🚀 Initializing RAG system for testing...")
        cls.rag_system = MaternalHealthRAG(use_mock_llm=True)

    def test_rag_system_initialization(self):
        """Test RAG system initializes correctly."""
        self.assertIsNotNone(self.rag_system.vector_store)
        self.assertIsNotNone(self.rag_system.llm)
        self.assertIsNotNone(self.rag_system.rag_chain)

        # System status should report an initialized, non-empty vector store.
        stats = self.rag_system.get_system_stats()
        self.assertEqual(stats['status'], 'initialized')
        self.assertGreater(stats['vector_store']['total_chunks'], 0)

    def test_basic_query_processing(self):
        """Test basic query processing returns a well-formed QueryResponse."""
        query = "What is magnesium sulfate used for?"
        response = self.rag_system.query(query)

        # Basic response validation: type, echoed query, non-empty answer,
        # confidence in [0, 1], positive response time.
        self.assertIsInstance(response, QueryResponse)
        self.assertEqual(response.query, query)
        self.assertIsInstance(response.answer, str)
        self.assertGreater(len(response.answer), 0)
        self.assertGreaterEqual(response.confidence, 0.0)
        self.assertLessEqual(response.confidence, 1.0)
        self.assertGreater(response.response_time, 0.0)

    def test_medical_context_queries(self):
        """Test queries with specific medical context filters."""
        test_cases = [
            {
                'query': 'What is the dosage for magnesium sulfate in preeclampsia?',
                'content_types': ['dosage', 'emergency'],
                'expected_keywords': ['magnesium', 'dosage', 'preeclampsia'],
                'min_confidence': 0.3
            },
            {
                'query': 'How to manage postpartum hemorrhage emergency?',
                'content_types': ['emergency', 'maternal'],
                'expected_keywords': ['hemorrhage', 'postpartum', 'emergency'],
                'min_confidence': 0.3
            },
            {
                'query': 'Normal fetal heart rate monitoring procedures',
                'content_types': ['procedure', 'maternal'],
                'expected_keywords': ['fetal', 'heart', 'rate', 'monitoring'],
                'min_confidence': 0.3
            }
        ]

        for case in test_cases:
            with self.subTest(query=case['query']):
                response = self.rag_system.query(
                    case['query'],
                    content_types=case['content_types']
                )

                # Response quality checks: sources returned, confidence floor met.
                self.assertGreater(len(response.sources), 0)
                self.assertGreaterEqual(response.confidence, case['min_confidence'])

                # At least one expected keyword must appear in the answer
                # or in the retrieved source chunks.
                combined_text = response.answer.lower()
                if response.sources:
                    combined_text += ' ' + ' '.join(
                        s.content.lower() for s in response.sources)
                keyword_found = any(
                    keyword in combined_text
                    for keyword in case['expected_keywords'])
                self.assertTrue(
                    keyword_found,
                    f"No expected keywords found for query: {case['query']}")

    def test_response_metadata(self):
        """Test response metadata is populated correctly."""
        query = "What are the signs of preeclampsia?"
        response = self.rag_system.query(query)

        # All metadata fields must be present.
        required_fields = ['num_sources', 'avg_relevance',
                           'content_types', 'high_importance_sources']
        for field in required_fields:
            self.assertIn(field, response.metadata)

        # Validate metadata value types and ranges.
        self.assertIsInstance(response.metadata['num_sources'], int)
        self.assertGreaterEqual(response.metadata['num_sources'], 0)
        self.assertIsInstance(response.metadata['avg_relevance'], float)
        self.assertIsInstance(response.metadata['content_types'], list)
        self.assertIsInstance(response.metadata['high_importance_sources'], int)

    def test_confidence_scoring(self):
        """Test confidence scoring mechanism."""
        # High-confidence query (should match the corpus well).
        high_conf_query = "magnesium sulfate for preeclampsia"
        high_response = self.rag_system.query(high_conf_query)

        # Low-confidence query (intentionally unspecific).
        low_conf_query = "medical procedures in general"
        low_response = self.rag_system.query(low_conf_query)

        # A specific medical query should clear the confidence floor.
        self.assertGreaterEqual(high_response.confidence, 0.3)

        # Both scores must stay within the valid [0, 1] range.
        for response in (high_response, low_response):
            self.assertGreaterEqual(response.confidence, 0.0)
            self.assertLessEqual(response.confidence, 1.0)

    def test_performance_metrics(self):
        """Test RAG system performance."""
        query = "What is normal labor management?"

        # Measure wall-clock time around the call as a sanity check on
        # the self-reported response_time.
        start_time = time.time()
        response = self.rag_system.query(query)
        actual_time = time.time() - start_time

        # Response should be fast (under 2 seconds for mock LLM).
        self.assertLess(response.response_time, 2.0)
        self.assertLess(actual_time, 3.0)

        # Should return a non-empty but bounded number of sources.
        self.assertGreater(len(response.sources), 0)
        self.assertLessEqual(len(response.sources), 10)

    def test_batch_query_processing(self):
        """Test batch query processing answers every query, in order."""
        queries = [
            "What is magnesium sulfate used for?",
            "How to manage labor complications?",
            "Normal fetal heart rate ranges"
        ]
        responses = self.rag_system.batch_query(queries)

        # One response per query, matched positionally.
        self.assertEqual(len(responses), len(queries))
        for i, response in enumerate(responses):
            self.assertIsInstance(response, QueryResponse)
            self.assertEqual(response.query, queries[i])
            self.assertGreater(len(response.answer), 0)

    def test_context_preparation(self):
        """Test context preparation from search results."""
        query = "preeclampsia management guidelines"
        response = self.rag_system.query(query)

        # Only meaningful when retrieval produced sources to build context from.
        if response.sources:
            # Context-grounded answers should have reasonable length.
            self.assertGreater(len(response.answer), 20)
            # Mock LLM should include a safety disclaimer.
            self.assertIn("healthcare", response.answer.lower())

    def test_error_handling(self):
        """Test error handling for edge cases."""
        # Empty query should still yield a valid response object.
        empty_response = self.rag_system.query("")
        self.assertIsInstance(empty_response, QueryResponse)
        self.assertIsInstance(empty_response.answer, str)

        # Very long query should not crash the pipeline.
        long_query = ("What is the management protocol for "
                      + "complicated " * 100 + "pregnancy cases?")
        long_response = self.rag_system.query(long_query)
        self.assertIsInstance(long_response, QueryResponse)

        # Query with special characters should also be handled.
        special_query = "What is the dosage for mg++ and other electrolytes?"
        special_response = self.rag_system.query(special_query)
        self.assertIsInstance(special_response, QueryResponse)

    def test_medical_safety_responses(self):
        """Test that responses include appropriate medical safety disclaimers."""
        queries = [
            "What medication should I take for preeclampsia?",
            "How much magnesium sulfate should I give?",
            "What should I do for emergency bleeding?"
        ]

        for query in queries:
            with self.subTest(query=query):
                response = self.rag_system.query(query)

                # Any of these terms counts as medical safety language.
                safety_terms = ['consult', 'healthcare', 'professional', 'medical']
                answer_lower = response.answer.lower()
                safety_found = any(term in answer_lower for term in safety_terms)
                self.assertTrue(
                    safety_found,
                    f"No safety disclaimer found in response to: {query}")

    def test_system_statistics(self):
        """Test system statistics functionality."""
        stats = self.rag_system.get_system_stats()

        # Check required top-level fields.
        required_fields = ['vector_store', 'rag_config', 'status']
        for field in required_fields:
            self.assertIn(field, stats)

        # Check vector store stats report a non-empty corpus.
        self.assertIn('total_chunks', stats['vector_store'])
        self.assertGreater(stats['vector_store']['total_chunks'], 0)

        # Check RAG config reflects the mock LLM backend chosen in setUpClass.
        self.assertIn('default_k', stats['rag_config'])
        self.assertIn('llm_type', stats['rag_config'])
        self.assertEqual(stats['rag_config']['llm_type'], 'mock')
def run_comprehensive_rag_tests():
    """Run all RAG tests with detailed reporting.

    Builds a suite from ``TestMaternalHealthRAG``, runs it with a verbose
    text runner, prints a human-readable summary (including the final line
    of each failure/error traceback), and reports overall success.

    Returns:
        bool: True when every test passed, False otherwise.
    """
    print("🧪 Running Comprehensive RAG System Tests...")
    print("=" * 60)

    # Create test suite from the test case class.
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromTestCase(TestMaternalHealthRAG)

    # Run tests with detailed per-test output.
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    # Print summary banner.
    print("\n" + "=" * 60)
    print("📊 RAG TEST SUMMARY:")
    print(f"   Tests run: {result.testsRun}")
    print(f"   Failures: {len(result.failures)}")
    print(f"   Errors: {len(result.errors)}")

    if result.wasSuccessful():
        print("✅ ALL RAG TESTS PASSED! RAG system is production-ready.")
        print("\n🎉 Key Validations Completed:")
        print("   ✅ RAG system initialization")
        print("   ✅ Query processing with medical context")
        print("   ✅ Confidence scoring and metadata")
        print("   ✅ Performance under 2 seconds")
        print("   ✅ Batch query processing")
        print("   ✅ Medical safety disclaimers")
        print("   ✅ Error handling and edge cases")
    else:
        print("❌ Some RAG tests failed. Check output above for details.")

        # Show only the final traceback line per failing test, to keep the
        # summary compact.  NOTE: loop variable renamed from `traceback` to
        # `tb` so it no longer shadows the stdlib `traceback` module name.
        if result.failures:
            print("\nFailures:")
            for test, tb in result.failures:
                lines = tb.strip().split('\n')
                error_line = lines[-1] if lines else "Unknown failure"
                print(f"   - {test}: {error_line}")

        if result.errors:
            print("\nErrors:")
            for test, tb in result.errors:
                lines = tb.strip().split('\n')
                error_line = lines[-1] if lines else "Unknown error"
                print(f"   - {test}: {error_line}")

    return result.wasSuccessful()
if __name__ == "__main__":
    # Exit nonzero when any test fails so CI pipelines detect regressions.
    # `raise SystemExit` is used instead of the site-module `exit()` helper,
    # which is intended for interactive use and not guaranteed to exist.
    success = run_comprehensive_rag_tests()
    raise SystemExit(0 if success else 1)