# vedaMD / src / test_rag_system.py
# sniro23 — "Initial commit without binary files" (commit 19aaa42)
#!/usr/bin/env python3
"""
Test Suite for Maternal Health RAG System
Validates complete RAG pipeline including vector retrieval and response generation
"""
import sys
import time
import unittest
from pathlib import Path
from typing import Any, Dict, List

from maternal_health_rag import MaternalHealthRAG, QueryResponse
class TestMaternalHealthRAG(unittest.TestCase):
    """End-to-end tests for the maternal-health RAG pipeline.

    A single RAG instance is shared by every test (built once in
    ``setUpClass``); it runs with a mock LLM so the suite needs no model
    download or API key and stays fast.
    """

    @classmethod
    def setUpClass(cls):
        """Build one shared RAG system for the whole test class."""
        print("🚀 Initializing RAG system for testing...")
        cls.rag_system = MaternalHealthRAG(use_mock_llm=True)

    def test_rag_system_initialization(self):
        """Core components exist and the vector store reports indexed chunks."""
        for component in (self.rag_system.vector_store,
                          self.rag_system.llm,
                          self.rag_system.rag_chain):
            self.assertIsNotNone(component)

        stats = self.rag_system.get_system_stats()
        self.assertEqual(stats['status'], 'initialized')
        self.assertGreater(stats['vector_store']['total_chunks'], 0)

    def test_basic_query_processing(self):
        """A simple query returns a well-formed QueryResponse."""
        question = "What is magnesium sulfate used for?"
        result = self.rag_system.query(question)

        self.assertIsInstance(result, QueryResponse)
        self.assertEqual(result.query, question)
        self.assertIsInstance(result.answer, str)
        self.assertGreater(len(result.answer), 0)
        # Confidence is a bounded score; response_time must be a measured positive.
        self.assertGreaterEqual(result.confidence, 0.0)
        self.assertLessEqual(result.confidence, 1.0)
        self.assertGreater(result.response_time, 0.0)

    def test_medical_context_queries(self):
        """Content-type-filtered queries return sources with expected vocabulary."""
        min_confidence = 0.3
        cases = (
            ('What is the dosage for magnesium sulfate in preeclampsia?',
             ['dosage', 'emergency'],
             ['magnesium', 'dosage', 'preeclampsia']),
            ('How to manage postpartum hemorrhage emergency?',
             ['emergency', 'maternal'],
             ['hemorrhage', 'postpartum', 'emergency']),
            ('Normal fetal heart rate monitoring procedures',
             ['procedure', 'maternal'],
             ['fetal', 'heart', 'rate', 'monitoring']),
        )

        for question, types, keywords in cases:
            with self.subTest(query=question):
                result = self.rag_system.query(question, content_types=types)

                self.assertGreater(len(result.sources), 0)
                self.assertGreaterEqual(result.confidence, min_confidence)

                # Expected vocabulary may appear in the answer itself or in
                # any of the retrieved source chunks.
                haystack = result.answer.lower()
                for source in result.sources:
                    haystack += ' ' + source.content.lower()
                self.assertTrue(
                    any(word in haystack for word in keywords),
                    f"No expected keywords found for query: {question}")

    def test_response_metadata(self):
        """Response metadata carries the expected keys with the expected types."""
        result = self.rag_system.query("What are the signs of preeclampsia?")

        expected_types = {
            'num_sources': int,
            'avg_relevance': float,
            'content_types': list,
            'high_importance_sources': int,
        }
        for key, kind in expected_types.items():
            self.assertIn(key, result.metadata)
            self.assertIsInstance(result.metadata[key], kind)
        self.assertGreaterEqual(result.metadata['num_sources'], 0)

    def test_confidence_scoring(self):
        """Specific medical queries score confidently; all scores stay in [0, 1]."""
        specific = self.rag_system.query("magnesium sulfate for preeclampsia")
        vague = self.rag_system.query("medical procedures in general")

        # A well-matched clinical query should clear the minimum threshold.
        self.assertGreaterEqual(specific.confidence, 0.3)

        for scored in (specific, vague):
            self.assertGreaterEqual(scored.confidence, 0.0)
            self.assertLessEqual(scored.confidence, 1.0)

    def test_performance_metrics(self):
        """Responses arrive quickly and cite a reasonable number of sources."""
        started = time.time()
        result = self.rag_system.query("What is normal labor management?")
        elapsed = time.time() - started

        # Mock LLM keeps latency low; wall-clock gets extra slack for overhead.
        self.assertLess(result.response_time, 2.0)
        self.assertLess(elapsed, 3.0)

        self.assertGreater(len(result.sources), 0)
        self.assertLessEqual(len(result.sources), 10)

    def test_batch_query_processing(self):
        """batch_query answers every question, preserving input order."""
        questions = [
            "What is magnesium sulfate used for?",
            "How to manage labor complications?",
            "Normal fetal heart rate ranges"
        ]
        results = self.rag_system.batch_query(questions)

        self.assertEqual(len(results), len(questions))
        for asked, answered in zip(questions, results):
            self.assertIsInstance(answered, QueryResponse)
            self.assertEqual(answered.query, asked)
            self.assertGreater(len(answered.answer), 0)

    def test_context_preparation(self):
        """When sources are retrieved, the answer is substantive and safe."""
        result = self.rag_system.query("preeclampsia management guidelines")

        if result.sources:
            self.assertGreater(len(result.answer), 20)
            # Mock LLM is expected to append a safety disclaimer.
            self.assertIn("healthcare", result.answer.lower())

    def test_error_handling(self):
        """Edge-case inputs still yield QueryResponse objects, not exceptions."""
        # Empty query.
        blank = self.rag_system.query("")
        self.assertIsInstance(blank, QueryResponse)
        self.assertIsInstance(blank.answer, str)

        # Very long query.
        oversized = "What is the management protocol for " + "complicated " * 100 + "pregnancy cases?"
        self.assertIsInstance(self.rag_system.query(oversized), QueryResponse)

        # Query with special characters.
        odd_chars = "What is the dosage for mg++ and other electrolytes?"
        self.assertIsInstance(self.rag_system.query(odd_chars), QueryResponse)

    def test_medical_safety_responses(self):
        """Clinically sensitive questions get answers with safety language."""
        safety_terms = ('consult', 'healthcare', 'professional', 'medical')
        questions = (
            "What medication should I take for preeclampsia?",
            "How much magnesium sulfate should I give?",
            "What should I do for emergency bleeding?",
        )

        for question in questions:
            with self.subTest(query=question):
                answer_lower = self.rag_system.query(question).answer.lower()
                self.assertTrue(
                    any(term in answer_lower for term in safety_terms),
                    f"No safety disclaimer found in response to: {question}")

    def test_system_statistics(self):
        """get_system_stats reports store size and the mock-LLM configuration."""
        stats = self.rag_system.get_system_stats()

        for key in ('vector_store', 'rag_config', 'status'):
            self.assertIn(key, stats)

        self.assertIn('total_chunks', stats['vector_store'])
        self.assertGreater(stats['vector_store']['total_chunks'], 0)

        for key in ('default_k', 'llm_type'):
            self.assertIn(key, stats['rag_config'])
        self.assertEqual(stats['rag_config']['llm_type'], 'mock')
def run_comprehensive_rag_tests():
    """Run the full RAG test suite and print a human-readable summary.

    Returns:
        bool: True when every test passed, False otherwise.
    """
    print("🧪 Running Comprehensive RAG System Tests...")
    print("=" * 60)

    # Build and execute the suite with verbose per-test output.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestMaternalHealthRAG)
    result = unittest.TextTestRunner(verbosity=2).run(suite)

    print("\n" + "=" * 60)
    print("📊 RAG TEST SUMMARY:")
    print(f" Tests run: {result.testsRun}")
    print(f" Failures: {len(result.failures)}")
    print(f" Errors: {len(result.errors)}")

    if result.wasSuccessful():
        print("✅ ALL RAG TESTS PASSED! RAG system is production-ready.")
        print("\n🎉 Key Validations Completed:")
        print(" ✅ RAG system initialization")
        print(" ✅ Query processing with medical context")
        print(" ✅ Confidence scoring and metadata")
        print(" ✅ Performance under 2 seconds")
        print(" ✅ Batch query processing")
        print(" ✅ Medical safety disclaimers")
        print(" ✅ Error handling and edge cases")
    else:
        print("❌ Some RAG tests failed. Check output above for details.")
        # Report only the final traceback line per test to keep output compact.
        # (Loop variable renamed from `traceback` to avoid shadowing the
        # stdlib module name.)
        sections = (
            ("\nFailures:", result.failures, "Unknown failure"),
            ("\nErrors:", result.errors, "Unknown error"),
        )
        for header, problems, fallback in sections:
            if problems:
                print(header)
                for test, trace in problems:
                    lines = trace.strip().split('\n')
                    last_line = lines[-1] if lines else fallback
                    print(f" - {test}: {last_line}")

    return result.wasSuccessful()
if __name__ == "__main__":
    # Propagate the suite outcome as the process exit code (0 = success)
    # so CI can fail the build on test failures. sys.exit is used instead
    # of the site-module exit(), which is not guaranteed to be available.
    success = run_comprehensive_rag_tests()
    sys.exit(0 if success else 1)