#!/usr/bin/env python3
"""
Test Suite for Maternal Health RAG System
Validates complete RAG pipeline including vector retrieval and response generation
"""
import unittest | |
import time | |
from typing import List, Dict, Any | |
from pathlib import Path | |
from maternal_health_rag import MaternalHealthRAG, QueryResponse | |
class TestMaternalHealthRAG(unittest.TestCase): | |
"""Test suite for RAG system functionality""" | |
def setUpClass(cls): | |
"""Set up test environment""" | |
print("π Initializing RAG system for testing...") | |
cls.rag_system = MaternalHealthRAG(use_mock_llm=True) | |
def test_rag_system_initialization(self): | |
"""Test RAG system initializes correctly""" | |
self.assertIsNotNone(self.rag_system.vector_store) | |
self.assertIsNotNone(self.rag_system.llm) | |
self.assertIsNotNone(self.rag_system.rag_chain) | |
# Check system status | |
stats = self.rag_system.get_system_stats() | |
self.assertEqual(stats['status'], 'initialized') | |
self.assertGreater(stats['vector_store']['total_chunks'], 0) | |
def test_basic_query_processing(self): | |
"""Test basic query processing functionality""" | |
query = "What is magnesium sulfate used for?" | |
response = self.rag_system.query(query) | |
# Basic response validation | |
self.assertIsInstance(response, QueryResponse) | |
self.assertEqual(response.query, query) | |
self.assertIsInstance(response.answer, str) | |
self.assertGreater(len(response.answer), 0) | |
self.assertGreaterEqual(response.confidence, 0.0) | |
self.assertLessEqual(response.confidence, 1.0) | |
self.assertGreater(response.response_time, 0.0) | |
def test_medical_context_queries(self): | |
"""Test queries with specific medical context""" | |
test_cases = [ | |
{ | |
'query': 'What is the dosage for magnesium sulfate in preeclampsia?', | |
'content_types': ['dosage', 'emergency'], | |
'expected_keywords': ['magnesium', 'dosage', 'preeclampsia'], | |
'min_confidence': 0.3 | |
}, | |
{ | |
'query': 'How to manage postpartum hemorrhage emergency?', | |
'content_types': ['emergency', 'maternal'], | |
'expected_keywords': ['hemorrhage', 'postpartum', 'emergency'], | |
'min_confidence': 0.3 | |
}, | |
{ | |
'query': 'Normal fetal heart rate monitoring procedures', | |
'content_types': ['procedure', 'maternal'], | |
'expected_keywords': ['fetal', 'heart', 'rate', 'monitoring'], | |
'min_confidence': 0.3 | |
} | |
] | |
for case in test_cases: | |
with self.subTest(query=case['query']): | |
response = self.rag_system.query( | |
case['query'], | |
content_types=case['content_types'] | |
) | |
# Response quality checks | |
self.assertGreater(len(response.sources), 0) | |
self.assertGreaterEqual(response.confidence, case['min_confidence']) | |
# Check if expected keywords appear in answer or sources | |
combined_text = response.answer.lower() | |
if response.sources: | |
combined_text += ' ' + ' '.join([s.content.lower() for s in response.sources]) | |
keyword_found = any(keyword in combined_text for keyword in case['expected_keywords']) | |
self.assertTrue(keyword_found, | |
f"No expected keywords found for query: {case['query']}") | |
def test_response_metadata(self): | |
"""Test response metadata is populated correctly""" | |
query = "What are the signs of preeclampsia?" | |
response = self.rag_system.query(query) | |
# Check metadata fields | |
required_fields = ['num_sources', 'avg_relevance', 'content_types', 'high_importance_sources'] | |
for field in required_fields: | |
self.assertIn(field, response.metadata) | |
# Validate metadata values | |
self.assertIsInstance(response.metadata['num_sources'], int) | |
self.assertGreaterEqual(response.metadata['num_sources'], 0) | |
self.assertIsInstance(response.metadata['avg_relevance'], float) | |
self.assertIsInstance(response.metadata['content_types'], list) | |
self.assertIsInstance(response.metadata['high_importance_sources'], int) | |
def test_confidence_scoring(self): | |
"""Test confidence scoring mechanism""" | |
# High-confidence query (should match well) | |
high_conf_query = "magnesium sulfate for preeclampsia" | |
high_response = self.rag_system.query(high_conf_query) | |
# Low-confidence query (less specific) | |
low_conf_query = "medical procedures in general" | |
low_response = self.rag_system.query(low_conf_query) | |
# Confidence should be higher for more specific medical queries | |
self.assertGreaterEqual(high_response.confidence, 0.3) | |
# Both should have valid confidence scores | |
self.assertGreaterEqual(high_response.confidence, 0.0) | |
self.assertLessEqual(high_response.confidence, 1.0) | |
self.assertGreaterEqual(low_response.confidence, 0.0) | |
self.assertLessEqual(low_response.confidence, 1.0) | |
def test_performance_metrics(self): | |
"""Test RAG system performance""" | |
query = "What is normal labor management?" | |
# Measure response time | |
start_time = time.time() | |
response = self.rag_system.query(query) | |
actual_time = time.time() - start_time | |
# Response should be fast (under 2 seconds for mock LLM) | |
self.assertLess(response.response_time, 2.0) | |
self.assertLess(actual_time, 3.0) | |
# Should return relevant sources | |
self.assertGreater(len(response.sources), 0) | |
self.assertLessEqual(len(response.sources), 10) # Should be reasonable number | |
def test_batch_query_processing(self): | |
"""Test batch query processing functionality""" | |
queries = [ | |
"What is magnesium sulfate used for?", | |
"How to manage labor complications?", | |
"Normal fetal heart rate ranges" | |
] | |
responses = self.rag_system.batch_query(queries) | |
# Should return same number of responses | |
self.assertEqual(len(responses), len(queries)) | |
# Each response should be valid | |
for i, response in enumerate(responses): | |
self.assertIsInstance(response, QueryResponse) | |
self.assertEqual(response.query, queries[i]) | |
self.assertGreater(len(response.answer), 0) | |
def test_context_preparation(self): | |
"""Test context preparation from search results""" | |
query = "preeclampsia management guidelines" | |
response = self.rag_system.query(query) | |
# Should have sources for context | |
if response.sources: | |
# Context should be prepared from these sources | |
self.assertGreater(len(response.answer), 20) # Reasonable answer length | |
# Mock LLM should include safety disclaimer | |
self.assertIn("healthcare", response.answer.lower()) | |
def test_error_handling(self): | |
"""Test error handling for edge cases""" | |
# Empty query | |
empty_response = self.rag_system.query("") | |
self.assertIsInstance(empty_response, QueryResponse) | |
self.assertIsInstance(empty_response.answer, str) | |
# Very long query | |
long_query = "What is the management protocol for " + "complicated " * 100 + "pregnancy cases?" | |
long_response = self.rag_system.query(long_query) | |
self.assertIsInstance(long_response, QueryResponse) | |
# Special characters query | |
special_query = "What is the dosage for mg++ and other electrolytes?" | |
special_response = self.rag_system.query(special_query) | |
self.assertIsInstance(special_response, QueryResponse) | |
def test_medical_safety_responses(self): | |
"""Test that responses include appropriate medical safety disclaimers""" | |
queries = [ | |
"What medication should I take for preeclampsia?", | |
"How much magnesium sulfate should I give?", | |
"What should I do for emergency bleeding?" | |
] | |
for query in queries: | |
with self.subTest(query=query): | |
response = self.rag_system.query(query) | |
# Should include medical safety language | |
safety_terms = ['consult', 'healthcare', 'professional', 'medical'] | |
answer_lower = response.answer.lower() | |
safety_found = any(term in answer_lower for term in safety_terms) | |
self.assertTrue(safety_found, | |
f"No safety disclaimer found in response to: {query}") | |
def test_system_statistics(self): | |
"""Test system statistics functionality""" | |
stats = self.rag_system.get_system_stats() | |
# Check required fields | |
required_fields = ['vector_store', 'rag_config', 'status'] | |
for field in required_fields: | |
self.assertIn(field, stats) | |
# Check vector store stats | |
self.assertIn('total_chunks', stats['vector_store']) | |
self.assertGreater(stats['vector_store']['total_chunks'], 0) | |
# Check RAG config | |
self.assertIn('default_k', stats['rag_config']) | |
self.assertIn('llm_type', stats['rag_config']) | |
self.assertEqual(stats['rag_config']['llm_type'], 'mock') | |
def run_comprehensive_rag_tests():
    """Run all RAG tests with detailed reporting"""
    print("π§ͺ Running Comprehensive RAG System Tests...")
    print("=" * 60)

    # Assemble and execute the suite with verbose per-test output
    suite = unittest.TestLoader().loadTestsFromTestCase(TestMaternalHealthRAG)
    result = unittest.TextTestRunner(verbosity=2).run(suite)

    # Summary banner
    print("\n" + "=" * 60)
    print("π RAG TEST SUMMARY:")
    print(f" Tests run: {result.testsRun}")
    print(f" Failures: {len(result.failures)}")
    print(f" Errors: {len(result.errors)}")

    ok = result.wasSuccessful()
    if ok:
        print("β ALL RAG TESTS PASSED! RAG system is production-ready.")
        print("\nπ Key Validations Completed:")
        for line in (
            " β RAG system initialization",
            " β Query processing with medical context",
            " β Confidence scoring and metadata",
            " β Performance under 2 seconds",
            " β Batch query processing",
            " β Medical safety disclaimers",
            " β Error handling and edge cases",
        ):
            print(line)
    else:
        print("β Some RAG tests failed. Check output above for details.")

        if result.failures:
            print("\nFailures:")
            for test, tb in result.failures:
                # Report only the last traceback line as the failure reason
                tail = tb.strip().split('\n')
                error_line = tail[-1] if tail else "Unknown failure"
                print(f" - {test}: {error_line}")

        if result.errors:
            print("\nErrors:")
            for test, tb in result.errors:
                tail = tb.strip().split('\n')
                error_line = tail[-1] if tail else "Unknown error"
                print(f" - {test}: {error_line}")

    return ok
if __name__ == "__main__":
    # Run the full suite and report status via the process exit code.
    success = run_comprehensive_rag_tests()
    # BUGFIX: use SystemExit instead of the site-module exit() helper,
    # which is not guaranteed to exist (e.g. under `python -S` or in
    # frozen/embedded interpreters).
    raise SystemExit(0 if success else 1)